- auth.py: rewrite to use PG via db.py (was sqlite3) - admin_cli.py: rewrite to use PG - migrate_auth_sqlite_to_pg.py: one-time migration script - SQLite arb.db no longer needed after migration
385 lines
14 KiB
Python
385 lines
14 KiB
Python
import base64
import hashlib
import hmac
import json
import logging
import os
import secrets
from datetime import datetime, timedelta, timezone
from typing import Optional

from fastapi import APIRouter, HTTPException, Header, Depends, Request
from pydantic import BaseModel, EmailStr

from db import get_sync_conn
|
|
|
|
# HMAC-SHA256 key used to sign and verify access tokens.
# The fallback below is committed to source control, so anyone who has read
# this file can mint valid tokens (including role="admin") for a deployment
# that relies on it. It is kept only so local/dev setups keep booting. An
# empty JWT_SECRET is treated as unset instead of being used as an empty
# (trivially forgeable) HMAC key.
_DEV_JWT_SECRET = "arb-engine-jwt-secret-v2-2026"
JWT_SECRET = os.getenv("JWT_SECRET") or _DEV_JWT_SECRET
if JWT_SECRET == _DEV_JWT_SECRET:
    logging.getLogger(__name__).warning(
        "JWT_SECRET is unset or equals the built-in development default; "
        "access tokens signed with it can be forged. Set JWT_SECRET in production."
    )

# Access-token lifetime (the JWT "exp" claim), in hours.
ACCESS_TOKEN_HOURS = 24

# Refresh-token lifetime (refresh_tokens.expires_at), in days.
REFRESH_TOKEN_DAYS = 7
|
|
|
|
# Every auth and admin endpoint in this module is mounted under /api and
# grouped under the "auth" tag.
router = APIRouter(tags=["auth"], prefix="/api")
|
|
|
|
|
|
# ─── PG Schema ───────────────────────────────────────────────
|
|
|
|
# DDL for the auth tables. ensure_tables() splits this on ";" and runs each
# piece separately, so no statement, comment or literal here may contain a
# semicolon.
# Timestamp columns are TEXT: values written from Python are naive-UTC
# isoformat() strings (compared as strings by the routes), while the
# DEFAULT (NOW()::TEXT) columns hold PostgreSQL's own text rendering.
# CREATE TABLE IF NOT EXISTS never alters an existing table, so changes here
# only take effect on a fresh database.
# invite_codes.created_by stores a users.id, hence BIGINT to match BIGSERIAL.
AUTH_SCHEMA = """
CREATE TABLE IF NOT EXISTS users (
    id BIGSERIAL PRIMARY KEY,
    email TEXT UNIQUE NOT NULL,
    password_hash TEXT NOT NULL,
    discord_id TEXT,
    role TEXT NOT NULL DEFAULT 'user',
    banned INTEGER NOT NULL DEFAULT 0,
    created_at TEXT NOT NULL
);

CREATE TABLE IF NOT EXISTS subscriptions (
    user_id BIGINT PRIMARY KEY REFERENCES users(id),
    tier TEXT NOT NULL DEFAULT 'free',
    expires_at TEXT
);

CREATE TABLE IF NOT EXISTS invite_codes (
    id BIGSERIAL PRIMARY KEY,
    code TEXT UNIQUE NOT NULL,
    created_by BIGINT,
    max_uses INTEGER NOT NULL DEFAULT 1,
    used_count INTEGER NOT NULL DEFAULT 0,
    status TEXT NOT NULL DEFAULT 'active',
    expires_at TEXT,
    created_at TEXT DEFAULT (NOW()::TEXT)
);

CREATE TABLE IF NOT EXISTS invite_usage (
    id BIGSERIAL PRIMARY KEY,
    code_id BIGINT NOT NULL REFERENCES invite_codes(id),
    user_id BIGINT NOT NULL REFERENCES users(id),
    used_at TEXT DEFAULT (NOW()::TEXT)
);

CREATE TABLE IF NOT EXISTS refresh_tokens (
    id BIGSERIAL PRIMARY KEY,
    user_id BIGINT NOT NULL REFERENCES users(id),
    token TEXT UNIQUE NOT NULL,
    expires_at TEXT NOT NULL,
    revoked INTEGER NOT NULL DEFAULT 0,
    created_at TEXT DEFAULT (NOW()::TEXT)
);
"""
|
|
|
|
|
|
def ensure_tables():
    """Create any missing auth tables from AUTH_SCHEMA (best effort).

    Each ";"-separated statement is executed and committed on its own. A
    statement that fails is rolled back and logged, and the remaining
    statements are still attempted.
    """
    log = logging.getLogger(__name__)
    with get_sync_conn() as conn:
        with conn.cursor() as cur:
            for raw_stmt in AUTH_SCHEMA.split(";"):
                stmt = raw_stmt.strip()
                if not stmt:
                    continue
                try:
                    cur.execute(stmt)
                    # Commit per statement: a rollback undoes everything that is
                    # uncommitted on the connection, so batching the commit would
                    # discard tables created earlier in this loop whenever a later
                    # statement fails.
                    conn.commit()
                except Exception:
                    # A failed statement aborts the PostgreSQL transaction; roll back
                    # so the connection can run the next statement.
                    # NOTE(review): assumes get_sync_conn() yields a non-autocommit
                    # connection; under autocommit these calls are harmless no-ops.
                    conn.rollback()
                    log.warning("auth schema statement failed: %.80s", stmt, exc_info=True)
|
|
|
|
|
|
# ─── DB helper ───────────────────────────────────────────────
|
|
|
|
def _fetchone(sql, params=None):
    """Run *sql* with *params* and return the first row as a column->value dict.

    Returns None when the query yields no row. No explicit commit is issued.
    NOTE(review): whether leaving the get_sync_conn() context ends the
    transaction depends on db.py; confirm before using this for writes.
    """
    with get_sync_conn() as conn, conn.cursor() as cur:
        cur.execute(sql, params or ())
        first = cur.fetchone()
        if not first:
            return None
        column_names = [column[0] for column in cur.description]
        return dict(zip(column_names, first))
|
|
|
|
|
|
def _fetchall(sql, params=None):
    """Run *sql* with *params* and return every row as a column->value dict.

    Intended for queries that return a result set (SELECT / ... RETURNING);
    no explicit commit is issued.
    """
    with get_sync_conn() as conn, conn.cursor() as cur:
        cur.execute(sql, params or ())
        column_names = [column[0] for column in cur.description]
        records = cur.fetchall()
        return [dict(zip(column_names, record)) for record in records]
|
|
|
|
|
|
def _execute(sql, params=None):
    """Execute one statement, commit it, and return its first result row.

    Returns:
        The first row (as returned by ``cursor.fetchone()``) for statements
        that produce a result set, e.g. ``UPDATE ... RETURNING``; ``None``
        when the statement has no result set or the result set is empty.

    Database errors propagate to the caller.
    """
    with get_sync_conn() as conn, conn.cursor() as cur:
        cur.execute(sql, params or ())
        # DB-API 2.0: cursor.description is None exactly when the statement
        # produced no result set. Checking it avoids a blanket try/except that
        # would also hide genuine fetch errors.
        row = cur.fetchone() if cur.description is not None else None
        conn.commit()
        return row
|
|
|
|
|
|
# ─── Password utils ──────────────────────────────────────────
|
|
|
|
def hash_password(password: str) -> str:
    """Hash *password* with scrypt under a fresh random salt.

    Returns ``"<salt>$<digest>"``: the salt is 32 hex characters (its ASCII
    bytes are what scrypt receives as salt) and the digest is the hex-encoded
    64-byte scrypt output with N=2**14, r=8, p=1.
    """
    salt_hex = secrets.token_hex(16)
    derived = hashlib.scrypt(
        password.encode(),
        salt=salt_hex.encode(),
        n=2 ** 14,
        r=8,
        p=1,
    )
    return "$".join((salt_hex, derived.hex()))
|
|
|
|
|
|
def verify_password(password: str, stored: str) -> bool:
    """Check *password* against a ``"<salt>$<digest>"`` string from hash_password.

    Uses a constant-time comparison. Any malformed *stored* value (no "$",
    not a string, non-ASCII digest, ...) yields False instead of raising.
    """
    try:
        salt, expected_hex = stored.split("$", 1)
        candidate_hex = hashlib.scrypt(
            password.encode(), salt=salt.encode(), n=2 ** 14, r=8, p=1
        ).hex()
        # compare_digest stays inside the try: it raises TypeError for
        # non-ASCII str input, which must map to False as well.
        matches = hmac.compare_digest(candidate_hex, expected_hex)
    except Exception:
        matches = False
    return matches
|
|
|
|
|
|
# ─── JWT utils ───────────────────────────────────────────────
|
|
|
|
def b64url(data: bytes) -> str:
    """Encode *data* as base64url text with the trailing "=" padding removed (JWT style)."""
    encoded = base64.urlsafe_b64encode(data).decode()
    return encoded.rstrip("=")
|
|
|
|
|
|
def _utc_now_epoch() -> int:
    """Return the current Unix time in whole seconds, independent of the host timezone."""
    # An aware datetime is required: .timestamp() on a naive datetime (such as
    # datetime.utcnow()) assumes *local* time, which skews the result by the
    # host's UTC offset whenever the server is not running in UTC.
    return int(datetime.now(timezone.utc).timestamp())


def create_access_token(user_id: int, email: str, role: str) -> str:
    """Return a compact HS256-signed JWT for *user_id*.

    Claims: ``sub`` (user id), ``email``, ``role``, ``exp`` (Unix seconds,
    ACCESS_TOKEN_HOURS from now) and ``type="access"``. The signature is
    HMAC-SHA256 over ``"<header>.<payload>"`` keyed with JWT_SECRET.
    """
    header = b64url(json.dumps({"alg": "HS256", "typ": "JWT"}, separators=(",", ":")).encode())
    # Computed from an aware UTC datetime so "exp" is a true Unix timestamp.
    exp = int((datetime.now(timezone.utc) + timedelta(hours=ACCESS_TOKEN_HOURS)).timestamp())
    claims = {"sub": user_id, "email": email, "role": role, "exp": exp, "type": "access"}
    payload = b64url(json.dumps(claims, separators=(",", ":")).encode())
    sign_input = f"{header}.{payload}".encode()
    signature = hmac.new(JWT_SECRET.encode(), sign_input, hashlib.sha256).digest()
    return f"{header}.{payload}.{b64url(signature)}"


def create_refresh_token(user_id: int) -> str:
    """Issue a new opaque refresh token for *user_id*, store it, and return it.

    The token is ``secrets.token_urlsafe(48)``. It is inserted into
    refresh_tokens with an expiry REFRESH_TOKEN_DAYS from now.
    """
    token = secrets.token_urlsafe(48)
    # Stored as a naive-UTC ISO-8601 string: the refresh route compares it as
    # text against datetime.utcnow().isoformat(), so the format must match.
    expires_at = (datetime.utcnow() + timedelta(days=REFRESH_TOKEN_DAYS)).isoformat()
    _execute(
        "INSERT INTO refresh_tokens (user_id, token, expires_at) VALUES (%s, %s, %s)",
        (user_id, token, expires_at)
    )
    return token


def parse_token(token: str) -> Optional[dict]:
    """Verify a JWT's signature and expiry and return its claims, or None.

    Any malformed input (wrong number of segments, bad base64/JSON, a payload
    that is not a JSON object, a non-integer ``exp``) yields None rather than
    raising. A missing ``exp`` counts as expired. The ``type`` claim is not
    checked here.
    """
    try:
        header_b64, payload_b64, sig_b64 = token.split(".")
        sign_input = f"{header_b64}.{payload_b64}".encode()
        # Always recompute as HS256, whatever "alg" the header claims, so a
        # forged header (e.g. "none") cannot change how the token is verified.
        expected = b64url(hmac.new(JWT_SECRET.encode(), sign_input, hashlib.sha256).digest())
        if not hmac.compare_digest(expected, sig_b64):
            return None
        pad = "=" * (-len(payload_b64) % 4)
        payload = json.loads(base64.urlsafe_b64decode(payload_b64 + pad))
        # Must use the same timezone-independent clock as create_access_token.
        if int(payload.get("exp", 0)) < _utc_now_epoch():
            return None
        return payload
    except Exception:
        return None
|
|
|
|
|
|
# ─── Auth dependency ─────────────────────────────────────────
|
|
|
|
def get_current_user(authorization: Optional[str] = Header(default=None)) -> dict:
    """FastAPI dependency that resolves the caller from an ``Authorization: Bearer`` header.

    Returns:
        The caller's full ``users`` row as a dict, re-read from the database
        on every request so bans and deletions take effect immediately.

    Raises:
        HTTPException 401: the header is missing or does not use the Bearer
            scheme; the token is invalid, expired, not an access token, or has
            no subject; or the user no longer exists.
        HTTPException 403: the user is banned.
    """
    # Authentication scheme names are case-insensitive (RFC 7235 §2.1), so
    # "bearer x" is accepted as well as "Bearer x".
    scheme, sep, credentials = (authorization or "").partition(" ")
    if not sep or scheme.lower() != "bearer":
        raise HTTPException(status_code=401, detail="missing token")
    payload = parse_token(credentials.strip())
    if not payload:
        raise HTTPException(status_code=401, detail="invalid or expired token")
    if payload.get("type") != "access":
        raise HTTPException(status_code=401, detail="invalid token type")
    user_id = payload.get("sub")
    if user_id is None:
        # A signed token without a subject would otherwise surface as a KeyError (500).
        raise HTTPException(status_code=401, detail="invalid or expired token")
    user = _fetchone("SELECT * FROM users WHERE id = %s", (user_id,))
    if not user:
        raise HTTPException(status_code=401, detail="user not found")
    if user["banned"]:
        raise HTTPException(status_code=403, detail="account banned")
    return user
|
|
|
|
|
|
def require_admin(user: dict = Depends(get_current_user)) -> dict:
    """FastAPI dependency that admits only authenticated users whose role is "admin".

    Returns the user row unchanged. Raises HTTPException 403 for any other role.
    """
    is_admin = user.get("role") == "admin"
    if is_admin:
        return user
    raise HTTPException(status_code=403, detail="admin required")
|
|
|
|
|
|
# ─── Request models ──────────────────────────────────────────
|
|
|
|
class RegisterReq(BaseModel):
    """Request body for ``POST /api/auth/register``."""

    email: EmailStr  # syntax-validated by pydantic's EmailStr
    password: str  # plaintext; no length/strength rule is enforced by this model
    invite_code: str  # looked up against invite_codes.code
|
|
|
|
|
|
class LoginReq(BaseModel):
    """Request body for ``POST /api/auth/login``."""

    email: EmailStr  # syntax-validated by pydantic's EmailStr
    password: str  # plaintext; checked against users.password_hash
|
|
|
|
|
|
class RefreshReq(BaseModel):
    """Request body for ``POST /api/auth/refresh``."""

    refresh_token: str  # opaque token previously issued by this API
|
|
|
|
|
|
class GenInviteReq(BaseModel):
    """Request body for ``POST /api/admin/invite-codes``."""

    count: int = 1  # number of codes to create; no bounds are enforced by this model
    max_uses: int = 1  # registrations allowed per code; not validated to be >= 1
|
|
|
|
|
|
class BanUserReq(BaseModel):
    """Request body for ``PUT /api/admin/users/{user_id}/ban``."""

    banned: bool = True  # True bans the user, False lifts the ban
|
|
|
|
|
|
# ─── Auth routes ─────────────────────────────────────────────
|
|
|
|
def _claim_invite_and_create_user(invite_id, email, pwd_hash):
    """Consume one use of invite *invite_id* and create the user, in one transaction.

    Claims the invite, inserts the ``users`` row, its default ``free``
    subscription and the ``invite_usage`` record, then commits.

    Returns:
        The new user's id, or None if the invite could no longer be claimed
        (used up or deactivated after the caller's pre-checks).

    Raises:
        Whatever the database driver raises (e.g. a unique violation on
        users.email); the transaction is rolled back first.
    """
    with get_sync_conn() as conn:
        try:
            with conn.cursor() as cur:
                # Conditional increment: the row lock taken by UPDATE serialises
                # concurrent registrations on the same code, and the WHERE clause
                # is re-evaluated under that lock, so max_uses cannot be exceeded.
                # Column references in SET see the pre-update values.
                cur.execute(
                    "UPDATE invite_codes "
                    "SET used_count = used_count + 1, "
                    "status = CASE WHEN used_count + 1 >= max_uses "
                    "THEN 'exhausted' ELSE status END "
                    "WHERE id = %s AND status = 'active' AND used_count < max_uses "
                    "RETURNING id",
                    (invite_id,),
                )
                if cur.fetchone() is None:
                    conn.rollback()
                    return None
                cur.execute(
                    "INSERT INTO users (email, password_hash, role, banned, created_at) VALUES (%s, %s, 'user', 0, %s) RETURNING id",
                    (email, pwd_hash, datetime.utcnow().isoformat()),
                )
                user_id = cur.fetchone()[0]
                cur.execute(
                    "INSERT INTO subscriptions (user_id, tier, expires_at) VALUES (%s, 'free', NULL) ON CONFLICT(user_id) DO NOTHING",
                    (user_id,),
                )
                cur.execute(
                    "INSERT INTO invite_usage (code_id, user_id) VALUES (%s, %s)",
                    (invite_id, user_id),
                )
            conn.commit()
        except Exception:
            # Leave the connection clean even if get_sync_conn() does not roll back on error.
            conn.rollback()
            raise
    return user_id


@router.post("/auth/register")
def register(body: RegisterReq):
    """Register a new user with an invite code and return an access/refresh token pair.

    The email is lower-cased before storage. The new user gets role "user"
    and a "free" subscription.

    Raises:
        HTTPException 400: the invite code is unknown, disabled, exhausted or
            expired, or the email is already registered.
        HTTPException 500: any other database failure (details are logged
            server-side, not returned to the client).
    """
    ensure_tables()

    invite = _fetchone("SELECT * FROM invite_codes WHERE code = %s", (body.invite_code,))
    # Pre-checks give precise error messages; the authoritative capacity check
    # happens atomically inside the registration transaction.
    if not invite:
        raise HTTPException(status_code=400, detail="invalid invite code")
    if invite["status"] != "active":
        raise HTTPException(status_code=400, detail="invite code disabled")
    if invite["used_count"] >= invite["max_uses"]:
        raise HTTPException(status_code=400, detail="invite code exhausted")
    # expires_at is compared as text against a naive-UTC ISO string.
    if invite["expires_at"] and invite["expires_at"] < datetime.utcnow().isoformat():
        raise HTTPException(status_code=400, detail="invite code expired")

    pwd_hash = hash_password(body.password)
    try:
        user_id = _claim_invite_and_create_user(invite["id"], body.email.lower(), pwd_hash)
    except Exception as e:
        err = str(e).lower()
        if "unique" in err or "duplicate" in err:
            raise HTTPException(status_code=400, detail="email already registered") from e
        # Do not echo driver error text (SQL, constraint names) to clients.
        logging.getLogger(__name__).exception("registration failed")
        raise HTTPException(status_code=500, detail="registration failed") from e
    if user_id is None:
        # Lost a race for the last use of the code, or it was deactivated meanwhile.
        raise HTTPException(status_code=400, detail="invite code exhausted")

    user = _fetchone("SELECT * FROM users WHERE id = %s", (user_id,))
    access = create_access_token(user["id"], user["email"], user["role"])
    refresh = create_refresh_token(user["id"])
    return {
        "access_token": access,
        "refresh_token": refresh,
        "token_type": "bearer",
        "expires_in": ACCESS_TOKEN_HOURS * 3600,
        "user": {"id": user["id"], "email": user["email"], "role": user["role"]},
    }
|
|
|
|
|
|
# Well-formed "<salt>$<digest>" value that will not match any real password.
# Verifying against it for unknown emails makes that path cost one scrypt,
# just like a real check.
_DUMMY_PASSWORD_HASH = "0" * 32 + "$" + "0" * 128


@router.post("/auth/login")
def login(body: LoginReq):
    """Authenticate by email and password and return an access/refresh token pair.

    The email is matched lower-cased.

    Raises:
        HTTPException 401: unknown email or wrong password (same response for both).
        HTTPException 403: the account is banned (only after a correct password).
    """
    ensure_tables()
    user = _fetchone("SELECT * FROM users WHERE email = %s", (body.email.lower(),))
    # Always run the scrypt check so response time does not reveal whether the
    # email is registered (user enumeration via timing).
    stored_hash = user["password_hash"] if user else _DUMMY_PASSWORD_HASH
    password_ok = verify_password(body.password, stored_hash)
    if not user or not password_ok:
        raise HTTPException(status_code=401, detail="invalid credentials")
    if user["banned"]:
        raise HTTPException(status_code=403, detail="account banned")
    access = create_access_token(user["id"], user["email"], user["role"])
    refresh = create_refresh_token(user["id"])
    return {
        "access_token": access,
        "refresh_token": refresh,
        "token_type": "bearer",
        "expires_in": ACCESS_TOKEN_HOURS * 3600,
        "user": {"id": user["id"], "email": user["email"], "role": user["role"]},
    }
|
|
|
|
|
|
@router.post("/auth/refresh")
def refresh_token(body: RefreshReq):
    """Exchange a refresh token for a new access token and a new refresh token.

    The presented refresh token is revoked in the same statement that looks it
    up, so it can be redeemed at most once even under concurrent requests. An
    expired token, or one belonging to a banned or missing user, is revoked as
    well before the request is rejected.

    Raises:
        HTTPException 401: the token is unknown, already revoked, or expired.
        HTTPException 403: the owning user is missing or banned.
    """
    claimed = _execute(
        "UPDATE refresh_tokens SET revoked = 1 WHERE token = %s AND revoked = 0 "
        "RETURNING user_id, expires_at",
        (body.refresh_token,),
    )
    if not claimed:
        raise HTTPException(status_code=401, detail="invalid refresh token")
    user_id, expires_at = claimed
    # expires_at is stored as naive-UTC ISO text, so a string comparison orders correctly.
    if expires_at < datetime.utcnow().isoformat():
        raise HTTPException(status_code=401, detail="refresh token expired")
    user = _fetchone("SELECT * FROM users WHERE id = %s", (user_id,))
    if not user or user["banned"]:
        raise HTTPException(status_code=403, detail="account unavailable")
    access = create_access_token(user["id"], user["email"], user["role"])
    new_refresh = create_refresh_token(user["id"])
    return {
        "access_token": access,
        "refresh_token": new_refresh,
        "token_type": "bearer",
        "expires_in": ACCESS_TOKEN_HOURS * 3600,
    }
|
|
|
|
|
|
@router.get("/auth/me")
def me(user: dict = Depends(get_current_user)):
    """Return the authenticated user's profile together with their subscription.

    A user without a ``subscriptions`` row is reported as tier "free" with no expiry.
    """
    user_id = user["id"]
    sub_row = _fetchone("SELECT tier, expires_at FROM subscriptions WHERE user_id = %s", (user_id,))
    if sub_row:
        subscription = dict(sub_row)
    else:
        subscription = {"tier": "free", "expires_at": None}
    return {
        "id": user_id,
        "email": user["email"],
        "role": user["role"],
        "discord_id": user.get("discord_id"),
        "created_at": user["created_at"],
        "subscription": subscription,
    }
|
|
|
|
|
|
# ─── Admin routes ────────────────────────────────────────────
|
|
|
|
@router.post("/admin/invite-codes")
def gen_invite_codes(body: GenInviteReq, admin: dict = Depends(require_admin)):
    """Admin: create ``body.count`` invite codes, each allowing ``body.max_uses`` registrations.

    Each code is ``secrets.token_urlsafe(6)`` (8 characters), upper-cased. A
    code that collides with an existing one is skipped and regenerated, so one
    collision no longer fails the whole batch. All codes are committed together.

    Returns:
        ``{"codes": [...]}`` in creation order; empty when ``count <= 0``.

    Raises:
        HTTPException 500: unique codes could not be produced within the retry
            budget (practically unreachable; guards against an endless loop).
    """
    codes = []
    attempts_left = max(body.count, 0) * 10 + 10
    with get_sync_conn() as conn:
        with conn.cursor() as cur:
            while len(codes) < body.count:
                if attempts_left <= 0:
                    conn.rollback()
                    raise HTTPException(status_code=500, detail="could not generate unique invite codes")
                attempts_left -= 1
                code = secrets.token_urlsafe(6)[:8].upper()
                # ON CONFLICT turns a duplicate code into "no row returned"
                # instead of an error that would abort the whole transaction.
                cur.execute(
                    "INSERT INTO invite_codes (code, created_by, max_uses) VALUES (%s, %s, %s) "
                    "ON CONFLICT (code) DO NOTHING RETURNING code",
                    (code, admin["id"], body.max_uses),
                )
                if cur.fetchone() is not None:
                    codes.append(code)
        conn.commit()
    return {"codes": codes}
|
|
|
|
|
|
@router.get("/admin/invite-codes")
def list_invite_codes(admin: dict = Depends(require_admin)):
    """Admin: list every invite code, newest (highest id) first."""
    query = (
        "SELECT id, code, max_uses, used_count, status, expires_at, created_at "
        "FROM invite_codes ORDER BY id DESC"
    )
    return {"items": _fetchall(query)}
|
|
|
|
|
|
@router.delete("/admin/invite-codes/{code_id}")
def disable_invite_code(code_id: int, admin: dict = Depends(require_admin)):
    """Admin: set an invite code's status to 'disabled'.

    Always answers ``{"ok": True}``, even if no code has that id.
    """
    params = (code_id,)
    _execute("UPDATE invite_codes SET status = 'disabled' WHERE id = %s", params)
    return {"ok": True}
|
|
|
|
|
|
@router.get("/admin/users")
def list_users(admin: dict = Depends(require_admin)):
    """Admin: list every user (without password hashes), newest (highest id) first."""
    columns = "id, email, role, banned, discord_id, created_at"
    users = _fetchall(f"SELECT {columns} FROM users ORDER BY id DESC")
    return {"items": users}
|
|
|
|
|
|
@router.put("/admin/users/{user_id}/ban")
def ban_user(user_id: int, body: BanUserReq, admin: dict = Depends(require_admin)):
    """Admin: ban or unban a user by writing 1/0 to users.banned.

    Echoes the requested state. Does not check that the user exists.
    """
    flag = int(body.banned)
    _execute("UPDATE users SET banned = %s WHERE id = %s", (flag, user_id))
    return {"ok": True, "user_id": user_id, "banned": body.banned}
|