feat: migrate auth system from SQLite to PostgreSQL
- auth.py: rewrite to use PG via db.py (was sqlite3) - admin_cli.py: rewrite to use PG - migrate_auth_sqlite_to_pg.py: one-time migration script - SQLite arb.db no longer needed after migration
This commit is contained in:
parent
4f54e36d1a
commit
8b73500d22
@ -1,82 +1,90 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
"""Admin CLI for Arbitrage Engine"""
|
"""Admin CLI for Arbitrage Engine (PostgreSQL version)"""
|
||||||
import sys, os, sqlite3, secrets, json
|
import sys, os, secrets
|
||||||
|
|
||||||
DB_PATH = os.path.join(os.path.dirname(__file__), "..", "arb.db")
|
sys.path.insert(0, os.path.dirname(__file__))
|
||||||
|
from db import get_sync_conn
|
||||||
|
|
||||||
def get_conn():
|
|
||||||
conn = sqlite3.connect(DB_PATH)
|
|
||||||
conn.row_factory = sqlite3.Row
|
|
||||||
return conn
|
|
||||||
|
|
||||||
def gen_invite(count=1, max_uses=1):
|
def gen_invite(count=1, max_uses=1):
|
||||||
conn = get_conn()
|
with get_sync_conn() as conn:
|
||||||
codes = []
|
with conn.cursor() as cur:
|
||||||
for _ in range(count):
|
codes = []
|
||||||
code = secrets.token_urlsafe(6)[:8].upper()
|
for _ in range(count):
|
||||||
conn.execute(
|
code = secrets.token_urlsafe(6)[:8].upper()
|
||||||
"INSERT INTO invite_codes (code, created_by, max_uses) VALUES (?, 1, ?)",
|
cur.execute(
|
||||||
(code, max_uses)
|
"INSERT INTO invite_codes (code, created_by, max_uses) VALUES (%s, 1, %s)",
|
||||||
)
|
(code, max_uses)
|
||||||
codes.append(code)
|
)
|
||||||
conn.commit()
|
codes.append(code)
|
||||||
conn.close()
|
conn.commit()
|
||||||
for c in codes:
|
for c in codes:
|
||||||
print(f" {c}")
|
print(f" {c}")
|
||||||
print(f"\nGenerated {len(codes)} invite code(s)")
|
print(f"\nGenerated {len(codes)} invite code(s)")
|
||||||
|
|
||||||
|
|
||||||
def list_invites():
|
def list_invites():
|
||||||
conn = get_conn()
|
with get_sync_conn() as conn:
|
||||||
rows = conn.execute("SELECT id, code, max_uses, used_count, status, created_at FROM invite_codes ORDER BY id DESC").fetchall()
|
with conn.cursor() as cur:
|
||||||
conn.close()
|
cur.execute("SELECT id, code, max_uses, used_count, status, created_at FROM invite_codes ORDER BY id DESC")
|
||||||
|
cols = [desc[0] for desc in cur.description]
|
||||||
|
rows = [dict(zip(cols, row)) for row in cur.fetchall()]
|
||||||
if not rows:
|
if not rows:
|
||||||
print("No invite codes found")
|
print("No invite codes found")
|
||||||
return
|
return
|
||||||
print(f"{'ID':>4} {'CODE':>10} {'MAX':>4} {'USED':>5} {'STATUS':>10} {'CREATED':>20}")
|
print(f"{'ID':>4} {'CODE':>10} {'MAX':>4} {'USED':>5} {'STATUS':>10} {'CREATED':>20}")
|
||||||
print("-" * 60)
|
print("-" * 60)
|
||||||
for r in rows:
|
for r in rows:
|
||||||
print(f"{r['id']:>4} {r['code']:>10} {r['max_uses']:>4} {r['used_count']:>5} {r['status']:>10} {r['created_at']:>20}")
|
print(f"{r['id']:>4} {r['code']:>10} {r['max_uses']:>4} {r['used_count']:>5} {r['status']:>10} {str(r['created_at']):>20}")
|
||||||
|
|
||||||
|
|
||||||
def disable_invite(code):
|
def disable_invite(code):
|
||||||
conn = get_conn()
|
with get_sync_conn() as conn:
|
||||||
conn.execute("UPDATE invite_codes SET status = 'disabled' WHERE code = ?", (code,))
|
with conn.cursor() as cur:
|
||||||
conn.commit()
|
cur.execute("UPDATE invite_codes SET status = 'disabled' WHERE code = %s", (code,))
|
||||||
conn.close()
|
conn.commit()
|
||||||
print(f"Disabled invite code: {code}")
|
print(f"Disabled invite code: {code}")
|
||||||
|
|
||||||
|
|
||||||
def list_users():
|
def list_users():
|
||||||
conn = get_conn()
|
with get_sync_conn() as conn:
|
||||||
rows = conn.execute("SELECT id, email, role, banned, created_at FROM users ORDER BY id DESC").fetchall()
|
with conn.cursor() as cur:
|
||||||
conn.close()
|
cur.execute("SELECT id, email, role, banned, created_at FROM users ORDER BY id DESC")
|
||||||
|
cols = [desc[0] for desc in cur.description]
|
||||||
|
rows = [dict(zip(cols, row)) for row in cur.fetchall()]
|
||||||
if not rows:
|
if not rows:
|
||||||
print("No users found")
|
print("No users found")
|
||||||
return
|
return
|
||||||
print(f"{'ID':>4} {'EMAIL':>30} {'ROLE':>6} {'BANNED':>7} {'CREATED':>20}")
|
print(f"{'ID':>4} {'EMAIL':>30} {'ROLE':>6} {'BANNED':>7} {'CREATED':>20}")
|
||||||
print("-" * 72)
|
print("-" * 72)
|
||||||
for r in rows:
|
for r in rows:
|
||||||
print(f"{r['id']:>4} {r['email']:>30} {r['role']:>6} {r['banned']:>7} {r['created_at']:>20}")
|
print(f"{r['id']:>4} {r['email']:>30} {r['role']:>6} {r['banned']:>7} {str(r['created_at']):>20}")
|
||||||
|
|
||||||
|
|
||||||
def ban_user(user_id):
|
def ban_user(user_id):
|
||||||
conn = get_conn()
|
with get_sync_conn() as conn:
|
||||||
conn.execute("UPDATE users SET banned = 1 WHERE id = ?", (user_id,))
|
with conn.cursor() as cur:
|
||||||
conn.commit()
|
cur.execute("UPDATE users SET banned = 1 WHERE id = %s", (user_id,))
|
||||||
conn.close()
|
conn.commit()
|
||||||
print(f"Banned user ID: {user_id}")
|
print(f"Banned user ID: {user_id}")
|
||||||
|
|
||||||
|
|
||||||
def unban_user(user_id):
|
def unban_user(user_id):
|
||||||
conn = get_conn()
|
with get_sync_conn() as conn:
|
||||||
conn.execute("UPDATE users SET banned = 0 WHERE id = ?", (user_id,))
|
with conn.cursor() as cur:
|
||||||
conn.commit()
|
cur.execute("UPDATE users SET banned = 0 WHERE id = %s", (user_id,))
|
||||||
conn.close()
|
conn.commit()
|
||||||
print(f"Unbanned user ID: {user_id}")
|
print(f"Unbanned user ID: {user_id}")
|
||||||
|
|
||||||
|
|
||||||
def set_admin(user_id):
|
def set_admin(user_id):
|
||||||
conn = get_conn()
|
with get_sync_conn() as conn:
|
||||||
conn.execute("UPDATE users SET role = 'admin' WHERE id = ?", (user_id,))
|
with conn.cursor() as cur:
|
||||||
conn.commit()
|
cur.execute("UPDATE users SET role = 'admin' WHERE id = %s", (user_id,))
|
||||||
conn.close()
|
conn.commit()
|
||||||
print(f"Set user {user_id} as admin")
|
print(f"Set user {user_id} as admin")
|
||||||
|
|
||||||
|
|
||||||
def usage():
|
def usage():
|
||||||
print("""Usage: python3 admin_cli.py <command> [args]
|
print("""Usage: python3 admin_cli.py <command> [args]
|
||||||
|
|
||||||
|
|||||||
335
backend/auth.py
335
backend/auth.py
@ -1,5 +1,4 @@
|
|||||||
import os
|
import os
|
||||||
import sqlite3
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import secrets
|
import secrets
|
||||||
import hmac
|
import hmac
|
||||||
@ -7,12 +6,12 @@ import base64
|
|||||||
import json
|
import json
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from functools import wraps
|
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Header, Depends, Request
|
from fastapi import APIRouter, HTTPException, Header, Depends, Request
|
||||||
from pydantic import BaseModel, EmailStr
|
from pydantic import BaseModel, EmailStr
|
||||||
|
|
||||||
DB_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "arb.db")
|
from db import get_sync_conn
|
||||||
|
|
||||||
JWT_SECRET = os.getenv("JWT_SECRET", "arb-engine-jwt-secret-v2-2026")
|
JWT_SECRET = os.getenv("JWT_SECRET", "arb-engine-jwt-secret-v2-2026")
|
||||||
ACCESS_TOKEN_HOURS = 24
|
ACCESS_TOKEN_HOURS = 24
|
||||||
REFRESH_TOKEN_DAYS = 7
|
REFRESH_TOKEN_DAYS = 7
|
||||||
@ -20,90 +19,98 @@ REFRESH_TOKEN_DAYS = 7
|
|||||||
router = APIRouter(prefix="/api", tags=["auth"])
|
router = APIRouter(prefix="/api", tags=["auth"])
|
||||||
|
|
||||||
|
|
||||||
# ─── DB helpers ───────────────────────────────────────────────
|
# ─── PG Schema ───────────────────────────────────────────────
|
||||||
|
|
||||||
def get_conn():
|
AUTH_SCHEMA = """
|
||||||
conn = sqlite3.connect(DB_PATH)
|
CREATE TABLE IF NOT EXISTS users (
|
||||||
conn.row_factory = sqlite3.Row
|
id BIGSERIAL PRIMARY KEY,
|
||||||
return conn
|
email TEXT UNIQUE NOT NULL,
|
||||||
|
password_hash TEXT NOT NULL,
|
||||||
|
discord_id TEXT,
|
||||||
|
role TEXT NOT NULL DEFAULT 'user',
|
||||||
|
banned INTEGER NOT NULL DEFAULT 0,
|
||||||
|
created_at TEXT NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS subscriptions (
|
||||||
|
user_id BIGINT PRIMARY KEY REFERENCES users(id),
|
||||||
|
tier TEXT NOT NULL DEFAULT 'free',
|
||||||
|
expires_at TEXT
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS invite_codes (
|
||||||
|
id BIGSERIAL PRIMARY KEY,
|
||||||
|
code TEXT UNIQUE NOT NULL,
|
||||||
|
created_by INTEGER,
|
||||||
|
max_uses INTEGER NOT NULL DEFAULT 1,
|
||||||
|
used_count INTEGER NOT NULL DEFAULT 0,
|
||||||
|
status TEXT NOT NULL DEFAULT 'active',
|
||||||
|
expires_at TEXT,
|
||||||
|
created_at TEXT DEFAULT (NOW()::TEXT)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS invite_usage (
|
||||||
|
id BIGSERIAL PRIMARY KEY,
|
||||||
|
code_id BIGINT NOT NULL REFERENCES invite_codes(id),
|
||||||
|
user_id BIGINT NOT NULL REFERENCES users(id),
|
||||||
|
used_at TEXT DEFAULT (NOW()::TEXT)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS refresh_tokens (
|
||||||
|
id BIGSERIAL PRIMARY KEY,
|
||||||
|
user_id BIGINT NOT NULL REFERENCES users(id),
|
||||||
|
token TEXT UNIQUE NOT NULL,
|
||||||
|
expires_at TEXT NOT NULL,
|
||||||
|
revoked INTEGER NOT NULL DEFAULT 0,
|
||||||
|
created_at TEXT DEFAULT (NOW()::TEXT)
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
def ensure_tables():
|
def ensure_tables():
|
||||||
conn = get_conn()
|
with get_sync_conn() as conn:
|
||||||
cur = conn.cursor()
|
with conn.cursor() as cur:
|
||||||
cur.execute("""
|
for stmt in AUTH_SCHEMA.split(";"):
|
||||||
CREATE TABLE IF NOT EXISTS users (
|
stmt = stmt.strip()
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
if stmt:
|
||||||
email TEXT UNIQUE NOT NULL,
|
try:
|
||||||
password_hash TEXT NOT NULL,
|
cur.execute(stmt)
|
||||||
discord_id TEXT,
|
except Exception:
|
||||||
role TEXT NOT NULL DEFAULT 'user',
|
conn.rollback()
|
||||||
banned INTEGER NOT NULL DEFAULT 0,
|
continue
|
||||||
created_at TEXT NOT NULL
|
conn.commit()
|
||||||
)
|
|
||||||
""")
|
|
||||||
cur.execute("""
|
# ─── DB helper ───────────────────────────────────────────────
|
||||||
CREATE TABLE IF NOT EXISTS subscriptions (
|
|
||||||
user_id INTEGER PRIMARY KEY,
|
def _fetchone(sql, params=None):
|
||||||
tier TEXT NOT NULL DEFAULT 'free',
|
with get_sync_conn() as conn:
|
||||||
expires_at TEXT,
|
with conn.cursor() as cur:
|
||||||
FOREIGN KEY(user_id) REFERENCES users(id)
|
cur.execute(sql, params or ())
|
||||||
)
|
row = cur.fetchone()
|
||||||
""")
|
if not row:
|
||||||
cur.execute("""
|
return None
|
||||||
CREATE TABLE IF NOT EXISTS signal_logs (
|
cols = [desc[0] for desc in cur.description]
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
return dict(zip(cols, row))
|
||||||
symbol TEXT NOT NULL,
|
|
||||||
rate REAL NOT NULL,
|
|
||||||
annualized REAL NOT NULL,
|
def _fetchall(sql, params=None):
|
||||||
sent_at TEXT NOT NULL,
|
with get_sync_conn() as conn:
|
||||||
message TEXT NOT NULL
|
with conn.cursor() as cur:
|
||||||
)
|
cur.execute(sql, params or ())
|
||||||
""")
|
cols = [desc[0] for desc in cur.description]
|
||||||
cur.execute("""
|
return [dict(zip(cols, row)) for row in cur.fetchall()]
|
||||||
CREATE TABLE IF NOT EXISTS invite_codes (
|
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
||||||
code TEXT UNIQUE NOT NULL,
|
def _execute(sql, params=None):
|
||||||
created_by INTEGER,
|
with get_sync_conn() as conn:
|
||||||
max_uses INTEGER NOT NULL DEFAULT 1,
|
with conn.cursor() as cur:
|
||||||
used_count INTEGER NOT NULL DEFAULT 0,
|
cur.execute(sql, params or ())
|
||||||
status TEXT NOT NULL DEFAULT 'active',
|
conn.commit()
|
||||||
expires_at TEXT,
|
try:
|
||||||
created_at TEXT DEFAULT (datetime('now'))
|
return cur.fetchone()
|
||||||
)
|
except Exception:
|
||||||
""")
|
return None
|
||||||
cur.execute("""
|
|
||||||
CREATE TABLE IF NOT EXISTS invite_usage (
|
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
||||||
code_id INTEGER NOT NULL,
|
|
||||||
user_id INTEGER NOT NULL,
|
|
||||||
used_at TEXT DEFAULT (datetime('now')),
|
|
||||||
FOREIGN KEY (code_id) REFERENCES invite_codes(id),
|
|
||||||
FOREIGN KEY (user_id) REFERENCES users(id)
|
|
||||||
)
|
|
||||||
""")
|
|
||||||
cur.execute("""
|
|
||||||
CREATE TABLE IF NOT EXISTS refresh_tokens (
|
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
||||||
user_id INTEGER NOT NULL,
|
|
||||||
token TEXT UNIQUE NOT NULL,
|
|
||||||
expires_at TEXT NOT NULL,
|
|
||||||
revoked INTEGER NOT NULL DEFAULT 0,
|
|
||||||
created_at TEXT DEFAULT (datetime('now')),
|
|
||||||
FOREIGN KEY (user_id) REFERENCES users(id)
|
|
||||||
)
|
|
||||||
""")
|
|
||||||
# migrate: add role/banned columns if missing
|
|
||||||
try:
|
|
||||||
cur.execute("ALTER TABLE users ADD COLUMN role TEXT NOT NULL DEFAULT 'user'")
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
try:
|
|
||||||
cur.execute("ALTER TABLE users ADD COLUMN banned INTEGER NOT NULL DEFAULT 0")
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
conn.commit()
|
|
||||||
conn.close()
|
|
||||||
|
|
||||||
|
|
||||||
# ─── Password utils ──────────────────────────────────────────
|
# ─── Password utils ──────────────────────────────────────────
|
||||||
@ -144,13 +151,10 @@ def create_access_token(user_id: int, email: str, role: str) -> str:
|
|||||||
def create_refresh_token(user_id: int) -> str:
|
def create_refresh_token(user_id: int) -> str:
|
||||||
token = secrets.token_urlsafe(48)
|
token = secrets.token_urlsafe(48)
|
||||||
expires_at = (datetime.utcnow() + timedelta(days=REFRESH_TOKEN_DAYS)).isoformat()
|
expires_at = (datetime.utcnow() + timedelta(days=REFRESH_TOKEN_DAYS)).isoformat()
|
||||||
conn = get_conn()
|
_execute(
|
||||||
conn.execute(
|
"INSERT INTO refresh_tokens (user_id, token, expires_at) VALUES (%s, %s, %s)",
|
||||||
"INSERT INTO refresh_tokens (user_id, token, expires_at) VALUES (?, ?, ?)",
|
|
||||||
(user_id, token, expires_at)
|
(user_id, token, expires_at)
|
||||||
)
|
)
|
||||||
conn.commit()
|
|
||||||
conn.close()
|
|
||||||
return token
|
return token
|
||||||
|
|
||||||
|
|
||||||
@ -173,7 +177,6 @@ def parse_token(token: str) -> Optional[dict]:
|
|||||||
# ─── Auth dependency ─────────────────────────────────────────
|
# ─── Auth dependency ─────────────────────────────────────────
|
||||||
|
|
||||||
def get_current_user(authorization: Optional[str] = Header(default=None)) -> dict:
|
def get_current_user(authorization: Optional[str] = Header(default=None)) -> dict:
|
||||||
"""FastAPI dependency: returns user dict or raises 401"""
|
|
||||||
if not authorization or not authorization.startswith("Bearer "):
|
if not authorization or not authorization.startswith("Bearer "):
|
||||||
raise HTTPException(status_code=401, detail="missing token")
|
raise HTTPException(status_code=401, detail="missing token")
|
||||||
token = authorization.split(" ", 1)[1].strip()
|
token = authorization.split(" ", 1)[1].strip()
|
||||||
@ -182,18 +185,15 @@ def get_current_user(authorization: Optional[str] = Header(default=None)) -> dic
|
|||||||
raise HTTPException(status_code=401, detail="invalid or expired token")
|
raise HTTPException(status_code=401, detail="invalid or expired token")
|
||||||
if payload.get("type") != "access":
|
if payload.get("type") != "access":
|
||||||
raise HTTPException(status_code=401, detail="invalid token type")
|
raise HTTPException(status_code=401, detail="invalid token type")
|
||||||
conn = get_conn()
|
user = _fetchone("SELECT * FROM users WHERE id = %s", (payload["sub"],))
|
||||||
user = conn.execute("SELECT * FROM users WHERE id = ?", (payload["sub"],)).fetchone()
|
|
||||||
conn.close()
|
|
||||||
if not user:
|
if not user:
|
||||||
raise HTTPException(status_code=401, detail="user not found")
|
raise HTTPException(status_code=401, detail="user not found")
|
||||||
if user["banned"]:
|
if user["banned"]:
|
||||||
raise HTTPException(status_code=403, detail="account banned")
|
raise HTTPException(status_code=403, detail="account banned")
|
||||||
return dict(user)
|
return user
|
||||||
|
|
||||||
|
|
||||||
def require_admin(user: dict = Depends(get_current_user)) -> dict:
|
def require_admin(user: dict = Depends(get_current_user)) -> dict:
|
||||||
"""FastAPI dependency: requires admin role"""
|
|
||||||
if user.get("role") != "admin":
|
if user.get("role") != "admin":
|
||||||
raise HTTPException(status_code=403, detail="admin required")
|
raise HTTPException(status_code=403, detail="admin required")
|
||||||
return user
|
return user
|
||||||
@ -230,60 +230,50 @@ class BanUserReq(BaseModel):
|
|||||||
@router.post("/auth/register")
|
@router.post("/auth/register")
|
||||||
def register(body: RegisterReq):
|
def register(body: RegisterReq):
|
||||||
ensure_tables()
|
ensure_tables()
|
||||||
conn = get_conn()
|
invite = _fetchone("SELECT * FROM invite_codes WHERE code = %s", (body.invite_code,))
|
||||||
|
|
||||||
# verify invite code
|
|
||||||
invite = conn.execute(
|
|
||||||
"SELECT * FROM invite_codes WHERE code = ?", (body.invite_code,)
|
|
||||||
).fetchone()
|
|
||||||
if not invite:
|
if not invite:
|
||||||
conn.close()
|
|
||||||
raise HTTPException(status_code=400, detail="invalid invite code")
|
raise HTTPException(status_code=400, detail="invalid invite code")
|
||||||
if invite["status"] != "active":
|
if invite["status"] != "active":
|
||||||
conn.close()
|
|
||||||
raise HTTPException(status_code=400, detail="invite code disabled")
|
raise HTTPException(status_code=400, detail="invite code disabled")
|
||||||
if invite["used_count"] >= invite["max_uses"]:
|
if invite["used_count"] >= invite["max_uses"]:
|
||||||
conn.close()
|
|
||||||
raise HTTPException(status_code=400, detail="invite code exhausted")
|
raise HTTPException(status_code=400, detail="invite code exhausted")
|
||||||
if invite["expires_at"] and invite["expires_at"] < datetime.utcnow().isoformat():
|
if invite["expires_at"] and invite["expires_at"] < datetime.utcnow().isoformat():
|
||||||
conn.close()
|
|
||||||
raise HTTPException(status_code=400, detail="invite code expired")
|
raise HTTPException(status_code=400, detail="invite code expired")
|
||||||
|
|
||||||
|
pwd_hash = hash_password(body.password)
|
||||||
try:
|
try:
|
||||||
pwd_hash = hash_password(body.password)
|
with get_sync_conn() as conn:
|
||||||
cur = conn.execute(
|
with conn.cursor() as cur:
|
||||||
"INSERT INTO users (email, password_hash, role, banned, created_at) VALUES (?, ?, 'user', 0, ?)",
|
cur.execute(
|
||||||
(body.email.lower(), pwd_hash, datetime.utcnow().isoformat()),
|
"INSERT INTO users (email, password_hash, role, banned, created_at) VALUES (%s, %s, 'user', 0, %s) RETURNING id",
|
||||||
)
|
(body.email.lower(), pwd_hash, datetime.utcnow().isoformat()),
|
||||||
user_id = cur.lastrowid
|
)
|
||||||
conn.execute(
|
user_id = cur.fetchone()[0]
|
||||||
"INSERT OR REPLACE INTO subscriptions (user_id, tier, expires_at) VALUES (?, 'free', NULL)",
|
cur.execute(
|
||||||
(user_id,),
|
"INSERT INTO subscriptions (user_id, tier, expires_at) VALUES (%s, 'free', NULL) ON CONFLICT(user_id) DO NOTHING",
|
||||||
)
|
(user_id,),
|
||||||
# record invite usage
|
)
|
||||||
conn.execute(
|
cur.execute(
|
||||||
"INSERT INTO invite_usage (code_id, user_id) VALUES (?, ?)",
|
"INSERT INTO invite_usage (code_id, user_id) VALUES (%s, %s)",
|
||||||
(invite["id"], user_id),
|
(invite["id"], user_id),
|
||||||
)
|
)
|
||||||
conn.execute(
|
cur.execute(
|
||||||
"UPDATE invite_codes SET used_count = used_count + 1 WHERE id = ?",
|
"UPDATE invite_codes SET used_count = used_count + 1 WHERE id = %s",
|
||||||
(invite["id"],),
|
(invite["id"],),
|
||||||
)
|
)
|
||||||
# auto-exhaust if max reached
|
new_count = invite["used_count"] + 1
|
||||||
new_count = invite["used_count"] + 1
|
if new_count >= invite["max_uses"]:
|
||||||
if new_count >= invite["max_uses"]:
|
cur.execute(
|
||||||
conn.execute(
|
"UPDATE invite_codes SET status = 'exhausted' WHERE id = %s",
|
||||||
"UPDATE invite_codes SET status = 'exhausted' WHERE id = ?",
|
(invite["id"],),
|
||||||
(invite["id"],),
|
)
|
||||||
)
|
conn.commit()
|
||||||
conn.commit()
|
except Exception as e:
|
||||||
except sqlite3.IntegrityError:
|
if "unique" in str(e).lower() or "duplicate" in str(e).lower():
|
||||||
conn.close()
|
raise HTTPException(status_code=400, detail="email already registered")
|
||||||
raise HTTPException(status_code=400, detail="email already registered")
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
# issue tokens
|
user = _fetchone("SELECT * FROM users WHERE id = %s", (user_id,))
|
||||||
user = conn.execute("SELECT * FROM users WHERE id = ?", (user_id,)).fetchone()
|
|
||||||
conn.close()
|
|
||||||
access = create_access_token(user["id"], user["email"], user["role"])
|
access = create_access_token(user["id"], user["email"], user["role"])
|
||||||
refresh = create_refresh_token(user["id"])
|
refresh = create_refresh_token(user["id"])
|
||||||
return {
|
return {
|
||||||
@ -298,9 +288,7 @@ def register(body: RegisterReq):
|
|||||||
@router.post("/auth/login")
|
@router.post("/auth/login")
|
||||||
def login(body: LoginReq):
|
def login(body: LoginReq):
|
||||||
ensure_tables()
|
ensure_tables()
|
||||||
conn = get_conn()
|
user = _fetchone("SELECT * FROM users WHERE email = %s", (body.email.lower(),))
|
||||||
user = conn.execute("SELECT * FROM users WHERE email = ?", (body.email.lower(),)).fetchone()
|
|
||||||
conn.close()
|
|
||||||
if not user or not verify_password(body.password, user["password_hash"]):
|
if not user or not verify_password(body.password, user["password_hash"]):
|
||||||
raise HTTPException(status_code=401, detail="invalid credentials")
|
raise HTTPException(status_code=401, detail="invalid credentials")
|
||||||
if user["banned"]:
|
if user["banned"]:
|
||||||
@ -318,24 +306,17 @@ def login(body: LoginReq):
|
|||||||
|
|
||||||
@router.post("/auth/refresh")
|
@router.post("/auth/refresh")
|
||||||
def refresh_token(body: RefreshReq):
|
def refresh_token(body: RefreshReq):
|
||||||
conn = get_conn()
|
row = _fetchone(
|
||||||
row = conn.execute(
|
"SELECT * FROM refresh_tokens WHERE token = %s AND revoked = 0", (body.refresh_token,)
|
||||||
"SELECT * FROM refresh_tokens WHERE token = ? AND revoked = 0", (body.refresh_token,)
|
)
|
||||||
).fetchone()
|
|
||||||
if not row:
|
if not row:
|
||||||
conn.close()
|
|
||||||
raise HTTPException(status_code=401, detail="invalid refresh token")
|
raise HTTPException(status_code=401, detail="invalid refresh token")
|
||||||
if row["expires_at"] < datetime.utcnow().isoformat():
|
if row["expires_at"] < datetime.utcnow().isoformat():
|
||||||
conn.close()
|
|
||||||
raise HTTPException(status_code=401, detail="refresh token expired")
|
raise HTTPException(status_code=401, detail="refresh token expired")
|
||||||
user = conn.execute("SELECT * FROM users WHERE id = ?", (row["user_id"],)).fetchone()
|
user = _fetchone("SELECT * FROM users WHERE id = %s", (row["user_id"],))
|
||||||
if not user or user["banned"]:
|
if not user or user["banned"]:
|
||||||
conn.close()
|
|
||||||
raise HTTPException(status_code=403, detail="account unavailable")
|
raise HTTPException(status_code=403, detail="account unavailable")
|
||||||
# revoke old, issue new
|
_execute("UPDATE refresh_tokens SET revoked = 1 WHERE id = %s", (row["id"],))
|
||||||
conn.execute("UPDATE refresh_tokens SET revoked = 1 WHERE id = ?", (row["id"],))
|
|
||||||
conn.commit()
|
|
||||||
conn.close()
|
|
||||||
access = create_access_token(user["id"], user["email"], user["role"])
|
access = create_access_token(user["id"], user["email"], user["role"])
|
||||||
new_refresh = create_refresh_token(user["id"])
|
new_refresh = create_refresh_token(user["id"])
|
||||||
return {
|
return {
|
||||||
@ -348,9 +329,7 @@ def refresh_token(body: RefreshReq):
|
|||||||
|
|
||||||
@router.get("/auth/me")
|
@router.get("/auth/me")
|
||||||
def me(user: dict = Depends(get_current_user)):
|
def me(user: dict = Depends(get_current_user)):
|
||||||
conn = get_conn()
|
sub = _fetchone("SELECT tier, expires_at FROM subscriptions WHERE user_id = %s", (user["id"],))
|
||||||
sub = conn.execute("SELECT tier, expires_at FROM subscriptions WHERE user_id = ?", (user["id"],)).fetchone()
|
|
||||||
conn.close()
|
|
||||||
return {
|
return {
|
||||||
"id": user["id"], "email": user["email"], "role": user["role"],
|
"id": user["id"], "email": user["email"], "role": user["role"],
|
||||||
"discord_id": user.get("discord_id"),
|
"discord_id": user.get("discord_id"),
|
||||||
@ -363,53 +342,43 @@ def me(user: dict = Depends(get_current_user)):
|
|||||||
|
|
||||||
@router.post("/admin/invite-codes")
|
@router.post("/admin/invite-codes")
|
||||||
def gen_invite_codes(body: GenInviteReq, admin: dict = Depends(require_admin)):
|
def gen_invite_codes(body: GenInviteReq, admin: dict = Depends(require_admin)):
|
||||||
conn = get_conn()
|
|
||||||
codes = []
|
codes = []
|
||||||
for _ in range(body.count):
|
with get_sync_conn() as conn:
|
||||||
code = secrets.token_urlsafe(6)[:8].upper()
|
with conn.cursor() as cur:
|
||||||
conn.execute(
|
for _ in range(body.count):
|
||||||
"INSERT INTO invite_codes (code, created_by, max_uses) VALUES (?, ?, ?)",
|
code = secrets.token_urlsafe(6)[:8].upper()
|
||||||
(code, admin["id"], body.max_uses),
|
cur.execute(
|
||||||
)
|
"INSERT INTO invite_codes (code, created_by, max_uses) VALUES (%s, %s, %s)",
|
||||||
codes.append(code)
|
(code, admin["id"], body.max_uses),
|
||||||
conn.commit()
|
)
|
||||||
conn.close()
|
codes.append(code)
|
||||||
|
conn.commit()
|
||||||
return {"codes": codes}
|
return {"codes": codes}
|
||||||
|
|
||||||
|
|
||||||
@router.get("/admin/invite-codes")
|
@router.get("/admin/invite-codes")
|
||||||
def list_invite_codes(admin: dict = Depends(require_admin)):
|
def list_invite_codes(admin: dict = Depends(require_admin)):
|
||||||
conn = get_conn()
|
rows = _fetchall(
|
||||||
rows = conn.execute(
|
|
||||||
"SELECT id, code, max_uses, used_count, status, expires_at, created_at FROM invite_codes ORDER BY id DESC"
|
"SELECT id, code, max_uses, used_count, status, expires_at, created_at FROM invite_codes ORDER BY id DESC"
|
||||||
).fetchall()
|
)
|
||||||
conn.close()
|
return {"items": rows}
|
||||||
return {"items": [dict(r) for r in rows]}
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/admin/invite-codes/{code_id}")
|
@router.delete("/admin/invite-codes/{code_id}")
|
||||||
def disable_invite_code(code_id: int, admin: dict = Depends(require_admin)):
|
def disable_invite_code(code_id: int, admin: dict = Depends(require_admin)):
|
||||||
conn = get_conn()
|
_execute("UPDATE invite_codes SET status = 'disabled' WHERE id = %s", (code_id,))
|
||||||
conn.execute("UPDATE invite_codes SET status = 'disabled' WHERE id = ?", (code_id,))
|
|
||||||
conn.commit()
|
|
||||||
conn.close()
|
|
||||||
return {"ok": True}
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
@router.get("/admin/users")
|
@router.get("/admin/users")
|
||||||
def list_users(admin: dict = Depends(require_admin)):
|
def list_users(admin: dict = Depends(require_admin)):
|
||||||
conn = get_conn()
|
rows = _fetchall(
|
||||||
rows = conn.execute(
|
|
||||||
"SELECT id, email, role, banned, discord_id, created_at FROM users ORDER BY id DESC"
|
"SELECT id, email, role, banned, discord_id, created_at FROM users ORDER BY id DESC"
|
||||||
).fetchall()
|
)
|
||||||
conn.close()
|
return {"items": rows}
|
||||||
return {"items": [dict(r) for r in rows]}
|
|
||||||
|
|
||||||
|
|
||||||
@router.put("/admin/users/{user_id}/ban")
|
@router.put("/admin/users/{user_id}/ban")
|
||||||
def ban_user(user_id: int, body: BanUserReq, admin: dict = Depends(require_admin)):
|
def ban_user(user_id: int, body: BanUserReq, admin: dict = Depends(require_admin)):
|
||||||
conn = get_conn()
|
_execute("UPDATE users SET banned = %s WHERE id = %s", (1 if body.banned else 0, user_id))
|
||||||
conn.execute("UPDATE users SET banned = ? WHERE id = ?", (1 if body.banned else 0, user_id))
|
|
||||||
conn.commit()
|
|
||||||
conn.close()
|
|
||||||
return {"ok": True, "user_id": user_id, "banned": body.banned}
|
return {"ok": True, "user_id": user_id, "banned": body.banned}
|
||||||
|
|||||||
170
backend/migrate_auth_sqlite_to_pg.py
Normal file
170
backend/migrate_auth_sqlite_to_pg.py
Normal file
@ -0,0 +1,170 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
migrate_auth_sqlite_to_pg.py — 将SQLite中的auth相关表迁移到PG
|
||||||
|
运行前确保PG连接参数正确(db.py中的配置)
|
||||||
|
"""
|
||||||
|
import os, sys, sqlite3
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.dirname(__file__))
|
||||||
|
from db import get_sync_conn
|
||||||
|
|
||||||
|
SQLITE_PATH = os.path.join(os.path.dirname(__file__), "..", "arb.db")
|
||||||
|
|
||||||
|
|
||||||
|
def migrate():
|
||||||
|
if not os.path.exists(SQLITE_PATH):
|
||||||
|
print(f"SQLite DB not found: {SQLITE_PATH}")
|
||||||
|
return
|
||||||
|
|
||||||
|
sq = sqlite3.connect(SQLITE_PATH)
|
||||||
|
sq.row_factory = sqlite3.Row
|
||||||
|
|
||||||
|
with get_sync_conn() as pg:
|
||||||
|
cur = pg.cursor()
|
||||||
|
|
||||||
|
# 1. 建auth相关表
|
||||||
|
print("Creating auth tables in PG...")
|
||||||
|
auth_tables = """
|
||||||
|
CREATE TABLE IF NOT EXISTS subscriptions (
|
||||||
|
user_id BIGINT PRIMARY KEY,
|
||||||
|
tier TEXT NOT NULL DEFAULT 'free',
|
||||||
|
expires_at TEXT
|
||||||
|
);
|
||||||
|
CREATE TABLE IF NOT EXISTS invite_usage (
|
||||||
|
id BIGSERIAL PRIMARY KEY,
|
||||||
|
code_id BIGINT NOT NULL,
|
||||||
|
user_id BIGINT NOT NULL,
|
||||||
|
used_at TEXT
|
||||||
|
);
|
||||||
|
CREATE TABLE IF NOT EXISTS refresh_tokens (
|
||||||
|
id BIGSERIAL PRIMARY KEY,
|
||||||
|
user_id BIGINT NOT NULL,
|
||||||
|
token TEXT UNIQUE NOT NULL,
|
||||||
|
expires_at TEXT NOT NULL,
|
||||||
|
revoked INTEGER NOT NULL DEFAULT 0,
|
||||||
|
created_at TEXT
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
for stmt in auth_tables.split(";"):
|
||||||
|
stmt = stmt.strip()
|
||||||
|
if stmt:
|
||||||
|
try:
|
||||||
|
cur.execute(stmt)
|
||||||
|
except Exception as e:
|
||||||
|
pg.rollback()
|
||||||
|
pg.commit()
|
||||||
|
|
||||||
|
# 2. 迁移users(PG已有users表但可能是空的旧结构)
|
||||||
|
print("Migrating users...")
|
||||||
|
# 先加缺失列
|
||||||
|
for col, defn in [("discord_id", "TEXT"), ("banned", "INTEGER DEFAULT 0")]:
|
||||||
|
try:
|
||||||
|
cur.execute(f"ALTER TABLE users ADD COLUMN {col} {defn}")
|
||||||
|
pg.commit()
|
||||||
|
except:
|
||||||
|
pg.rollback()
|
||||||
|
|
||||||
|
rows = sq.execute("SELECT * FROM users").fetchall()
|
||||||
|
for r in rows:
|
||||||
|
try:
|
||||||
|
cur.execute(
|
||||||
|
"""INSERT INTO users (id, email, password_hash, discord_id, role, banned, created_at)
|
||||||
|
VALUES (%s, %s, %s, %s, %s, %s, %s)
|
||||||
|
ON CONFLICT (id) DO UPDATE SET
|
||||||
|
password_hash = EXCLUDED.password_hash,
|
||||||
|
discord_id = EXCLUDED.discord_id,
|
||||||
|
role = EXCLUDED.role,
|
||||||
|
banned = EXCLUDED.banned""",
|
||||||
|
(r["id"], r["email"], r["password_hash"],
|
||||||
|
r["discord_id"] if "discord_id" in r.keys() else None,
|
||||||
|
r["role"], r["banned"], r["created_at"])
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(f" User {r['email']} error: {e}")
|
||||||
|
pg.rollback()
|
||||||
|
continue
|
||||||
|
pg.commit()
|
||||||
|
print(f" Migrated {len(rows)} users")
|
||||||
|
|
||||||
|
# 3. 迁移invite_codes(PG已有但可能缺列)
|
||||||
|
print("Migrating invite_codes...")
|
||||||
|
for col, defn in [("created_by", "INTEGER"), ("max_uses", "INTEGER DEFAULT 1"),
|
||||||
|
("used_count", "INTEGER DEFAULT 0"), ("status", "TEXT DEFAULT 'active'"),
|
||||||
|
("expires_at", "TEXT")]:
|
||||||
|
try:
|
||||||
|
cur.execute(f"ALTER TABLE invite_codes ADD COLUMN {col} {defn}")
|
||||||
|
pg.commit()
|
||||||
|
except:
|
||||||
|
pg.rollback()
|
||||||
|
|
||||||
|
try:
|
||||||
|
rows = sq.execute("SELECT * FROM invite_codes").fetchall()
|
||||||
|
for r in rows:
|
||||||
|
try:
|
||||||
|
cur.execute(
|
||||||
|
"""INSERT INTO invite_codes (id, code, created_by, max_uses, used_count, status, expires_at, created_at)
|
||||||
|
VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
|
||||||
|
ON CONFLICT (id) DO NOTHING""",
|
||||||
|
(r["id"], r["code"], r["created_by"] if "created_by" in r.keys() else None,
|
||||||
|
r["max_uses"], r["used_count"], r["status"],
|
||||||
|
r["expires_at"] if "expires_at" in r.keys() else None,
|
||||||
|
r["created_at"] if "created_at" in r.keys() else None)
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(f" Invite {r['code']} error: {e}")
|
||||||
|
pg.rollback()
|
||||||
|
continue
|
||||||
|
pg.commit()
|
||||||
|
print(f" Migrated {len(rows)} invite codes")
|
||||||
|
except Exception as e:
|
||||||
|
print(f" invite_codes table error: {e}")
|
||||||
|
|
||||||
|
# 4. 迁移subscriptions
|
||||||
|
print("Migrating subscriptions...")
|
||||||
|
try:
|
||||||
|
rows = sq.execute("SELECT * FROM subscriptions").fetchall()
|
||||||
|
for r in rows:
|
||||||
|
try:
|
||||||
|
cur.execute(
|
||||||
|
"INSERT INTO subscriptions (user_id, tier, expires_at) VALUES (%s, %s, %s) ON CONFLICT(user_id) DO NOTHING",
|
||||||
|
(r["user_id"], r["tier"], r["expires_at"])
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
pg.rollback()
|
||||||
|
pg.commit()
|
||||||
|
print(f" Migrated {len(rows)} subscriptions")
|
||||||
|
except Exception as e:
|
||||||
|
print(f" subscriptions error: {e}")
|
||||||
|
|
||||||
|
# 5. 迁移refresh_tokens
|
||||||
|
print("Migrating refresh_tokens...")
|
||||||
|
try:
|
||||||
|
rows = sq.execute("SELECT * FROM refresh_tokens").fetchall()
|
||||||
|
for r in rows:
|
||||||
|
try:
|
||||||
|
cur.execute(
|
||||||
|
"INSERT INTO refresh_tokens (user_id, token, expires_at, revoked) VALUES (%s, %s, %s, %s) ON CONFLICT(token) DO NOTHING",
|
||||||
|
(r["user_id"], r["token"], r["expires_at"], r["revoked"])
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
pg.rollback()
|
||||||
|
pg.commit()
|
||||||
|
print(f" Migrated {len(rows)} refresh tokens")
|
||||||
|
except Exception as e:
|
||||||
|
print(f" refresh_tokens error: {e}")
|
||||||
|
|
||||||
|
# 6. 重置序列
|
||||||
|
print("Resetting sequences...")
|
||||||
|
for table in ["users", "invite_codes", "invite_usage", "refresh_tokens"]:
|
||||||
|
try:
|
||||||
|
cur.execute(f"SELECT setval(pg_get_serial_sequence('{table}', 'id'), COALESCE(MAX(id), 1)) FROM {table}")
|
||||||
|
pg.commit()
|
||||||
|
except:
|
||||||
|
pg.rollback()
|
||||||
|
|
||||||
|
sq.close()
|
||||||
|
print("\nAuth migration complete!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
migrate()
|
||||||
Loading…
Reference in New Issue
Block a user