import os
import hashlib
import secrets
import hmac
import base64
import json
from datetime import datetime, timedelta, timezone
from typing import Optional

from fastapi import APIRouter, HTTPException, Header, Depends, Request
from pydantic import BaseModel, EmailStr

from db import get_sync_conn

_TRADE_ENV = os.getenv("TRADE_ENV", "testnet")
# Publicly known fallback secret, used only when TRADE_ENV == "testnet".
# It must be at least 32 characters long, otherwise the length check below
# would reject it and the module would fail to import even on testnet.
_jwt_default = "arb-engine-jwt-secret-v2-2026-testnet" if _TRADE_ENV == "testnet" else None
JWT_SECRET = os.getenv("JWT_SECRET") or _jwt_default
if not JWT_SECRET or len(JWT_SECRET) < 32:
    raise RuntimeError("JWT_SECRET 未配置或长度不足(>=32),生产环境必须设置环境变量")

ACCESS_TOKEN_HOURS = 24
REFRESH_TOKEN_DAYS = 7

router = APIRouter(prefix="/api", tags=["auth"])

# ─── PG Schema ───────────────────────────────────────────────

AUTH_SCHEMA = """
CREATE TABLE IF NOT EXISTS users (
    id BIGSERIAL PRIMARY KEY,
    email TEXT UNIQUE NOT NULL,
    password_hash TEXT NOT NULL,
    discord_id TEXT,
    role TEXT NOT NULL DEFAULT 'user',
    banned INTEGER NOT NULL DEFAULT 0,
    created_at TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS subscriptions (
    user_id BIGINT PRIMARY KEY REFERENCES users(id),
    tier TEXT NOT NULL DEFAULT 'free',
    expires_at TEXT
);
CREATE TABLE IF NOT EXISTS invite_codes (
    id BIGSERIAL PRIMARY KEY,
    code TEXT UNIQUE NOT NULL,
    created_by INTEGER,
    max_uses INTEGER NOT NULL DEFAULT 1,
    used_count INTEGER NOT NULL DEFAULT 0,
    status TEXT NOT NULL DEFAULT 'active',
    expires_at TEXT,
    created_at TEXT DEFAULT (NOW()::TEXT)
);
CREATE TABLE IF NOT EXISTS invite_usage (
    id BIGSERIAL PRIMARY KEY,
    code_id BIGINT NOT NULL REFERENCES invite_codes(id),
    user_id BIGINT NOT NULL REFERENCES users(id),
    used_at TEXT DEFAULT (NOW()::TEXT)
);
CREATE TABLE IF NOT EXISTS refresh_tokens (
    id BIGSERIAL PRIMARY KEY,
    user_id BIGINT NOT NULL REFERENCES users(id),
    token TEXT UNIQUE NOT NULL,
    expires_at TEXT NOT NULL,
    revoked INTEGER NOT NULL DEFAULT 0,
    created_at TEXT DEFAULT (NOW()::TEXT)
);
"""

# Set once every statement of AUTH_SCHEMA has executed successfully, so later
# calls to ensure_tables() skip the DDL round-trips.
_tables_ready = False


def ensure_tables():
    """Create the auth tables if they do not exist yet.

    Each statement is committed individually: a failing statement is rolled
    back on its own without discarding DDL that already succeeded. Failures
    are tolerated (best-effort), but then the work is retried on the next call.
    """
    global _tables_ready
    if _tables_ready:
        return
    all_ok = True
    with get_sync_conn() as conn:
        with conn.cursor() as cur:
            for stmt in AUTH_SCHEMA.split(";"):
                stmt = stmt.strip()
                if not stmt:
                    continue
                try:
                    cur.execute(stmt)
                    conn.commit()
                except Exception:
                    # Best-effort: undo only this statement and keep going.
                    conn.rollback()
                    all_ok = False
    _tables_ready = all_ok


# ─── Time helpers ────────────────────────────────────────────

def _utcnow() -> datetime:
    """Return the current UTC time as a naive datetime.

    The DB stores timestamps as naive-UTC ISO strings and compares them
    lexicographically, so this keeps the stored format unchanged.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


def _epoch_now() -> int:
    """Return the current Unix time in whole seconds.

    An aware datetime is used because ``naive.timestamp()`` interprets the
    value as *local* time, which skews epochs on non-UTC servers.
    """
    return int(datetime.now(timezone.utc).timestamp())


# ─── DB helper ───────────────────────────────────────────────

def _fetchone(sql, params=None):
    """Run *sql* and return the first row as a ``{column: value}`` dict, or None."""
    with get_sync_conn() as conn:
        with conn.cursor() as cur:
            cur.execute(sql, params or ())
            row = cur.fetchone()
            if not row:
                return None
            cols = [desc[0] for desc in cur.description]
            return dict(zip(cols, row))


def _fetchall(sql, params=None):
    """Run *sql* and return every row as a list of ``{column: value}`` dicts."""
    with get_sync_conn() as conn:
        with conn.cursor() as cur:
            cur.execute(sql, params or ())
            cols = [desc[0] for desc in cur.description]
            return [dict(zip(cols, row)) for row in cur.fetchall()]


def _execute(sql, params=None):
    """Run a write statement and commit it.

    Returns the first result row (e.g. from ``RETURNING``) as a tuple, or
    None when the statement produced no result set or matched no rows.
    """
    with get_sync_conn() as conn:
        with conn.cursor() as cur:
            cur.execute(sql, params or ())
            # cur.description is None for statements that return no rows.
            row = cur.fetchone() if cur.description is not None else None
        conn.commit()
        return row


# ─── Password utils ──────────────────────────────────────────

def hash_password(password: str) -> str:
    """Hash *password* with scrypt and a random salt; returns ``"<salt>$<hexdigest>"``."""
    salt = secrets.token_hex(16)
    digest = hashlib.scrypt(password.encode(), salt=salt.encode(), n=2**14, r=8, p=1).hex()
    return f"{salt}${digest}"


def verify_password(password: str, stored: str) -> bool:
    """Check *password* against a value produced by :func:`hash_password`.

    Uses a constant-time comparison. Returns False for a malformed *stored* value.
    """
    try:
        salt, digest = stored.split("$", 1)
        candidate = hashlib.scrypt(password.encode(), salt=salt.encode(), n=2**14, r=8, p=1).hex()
        return hmac.compare_digest(candidate, digest)
    except Exception:
        return False


# ─── JWT utils ───────────────────────────────────────────────

def b64url(data: bytes) -> str:
    """Base64url-encode *data* without ``=`` padding (JWT segment encoding)."""
    return base64.urlsafe_b64encode(data).rstrip(b"=").decode()


def _sign(signing_input: str) -> str:
    """Return the base64url HS256 signature of ``"<header>.<payload>"``."""
    return b64url(hmac.new(JWT_SECRET.encode(), signing_input.encode(), hashlib.sha256).digest())


def create_access_token(user_id: int, email: str, role: str) -> str:
    """Build a signed HS256 access JWT valid for ACCESS_TOKEN_HOURS hours."""
    header = b64url(json.dumps({"alg": "HS256", "typ": "JWT"}, separators=(",", ":")).encode())
    exp = _epoch_now() + ACCESS_TOKEN_HOURS * 3600
    payload = b64url(json.dumps({
        "sub": user_id,
        "email": email,
        "role": role,
        "exp": exp,
        "type": "access",
    }, separators=(",", ":")).encode())
    signing_input = f"{header}.{payload}"
    return f"{signing_input}.{_sign(signing_input)}"


def create_refresh_token(user_id: int) -> str:
    """Create, persist and return an opaque refresh token valid for REFRESH_TOKEN_DAYS days."""
    token = secrets.token_urlsafe(48)
    expires_at = (_utcnow() + timedelta(days=REFRESH_TOKEN_DAYS)).isoformat()
    _execute(
        "INSERT INTO refresh_tokens (user_id, token, expires_at) VALUES (%s, %s, %s)",
        (user_id, token, expires_at),
    )
    return token


def parse_token(token: str) -> Optional[dict]:
    """Verify *token*'s HS256 signature and expiry and return its payload.

    The header's ``alg`` is ignored; the signature is always checked with
    HS256. Any malformed, tampered or expired token yields None.
    """
    try:
        header_b64, payload_b64, sig_b64 = token.split(".")
        if not hmac.compare_digest(_sign(f"{header_b64}.{payload_b64}"), sig_b64):
            return None
        pad = "=" * (-len(payload_b64) % 4)
        payload = json.loads(base64.urlsafe_b64decode(payload_b64 + pad))
        if int(payload.get("exp", 0)) < _epoch_now():
            return None
        return payload
    except Exception:
        # Trust boundary: any decoding/shape error means "not a valid token".
        return None


# ─── Auth dependency ─────────────────────────────────────────

def get_current_user(authorization: Optional[str] = Header(default=None)) -> dict:
    """FastAPI dependency resolving the ``Authorization: Bearer <jwt>`` header to a user row.

    Raises:
        HTTPException(401): missing/invalid/expired token, wrong token type,
            or unknown user.
        HTTPException(403): the user is banned.
    """
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="missing token")
    token = authorization.split(" ", 1)[1].strip()
    payload = parse_token(token)
    if not payload:
        raise HTTPException(status_code=401, detail="invalid or expired token")
    if payload.get("type") != "access":
        raise HTTPException(status_code=401, detail="invalid token type")
    user_id = payload.get("sub")
    if user_id is None:
        # A correctly signed token without a subject must not surface as a 500.
        raise HTTPException(status_code=401, detail="invalid token")
    user = _fetchone("SELECT * FROM users WHERE id = %s", (user_id,))
    if not user:
        raise HTTPException(status_code=401, detail="user not found")
    if user["banned"]:
        raise HTTPException(status_code=403, detail="account banned")
    return user


def require_admin(user: dict = Depends(get_current_user)) -> dict:
    """FastAPI dependency that only admits users whose role is ``admin`` (else 403)."""
    if user.get("role") != "admin":
        raise HTTPException(status_code=403, detail="admin required")
    return user


# ─── Request models ──────────────────────────────────────────

class RegisterReq(BaseModel):
    email: EmailStr
    password: str
    invite_code: str


class LoginReq(BaseModel):
    email: EmailStr
    password: str


class RefreshReq(BaseModel):
    refresh_token: str


class GenInviteReq(BaseModel):
    count: int = 1
    max_uses: int = 1


class BanUserReq(BaseModel):
    banned: bool = True


# ─── Auth routes ─────────────────────────────────────────────

def _token_response(user: dict) -> dict:
    """Mint an access/refresh token pair for *user* and build the auth response body."""
    return {
        "access_token": create_access_token(user["id"], user["email"], user["role"]),
        "refresh_token": create_refresh_token(user["id"]),
        "token_type": "bearer",
        "expires_in": ACCESS_TOKEN_HOURS * 3600,
        "user": {"id": user["id"], "email": user["email"], "role": user["role"]},
    }


def _register_user(email: str, pwd_hash: str, invite_id: int) -> int:
    """Create the user, its free subscription and the invite usage in one transaction.

    The invite use is claimed with a conditional UPDATE, so concurrent
    registrations cannot push ``used_count`` past ``max_uses``. Everything is
    rolled back if any step fails.

    Returns:
        The new user's id.
    Raises:
        HTTPException(400): the invite can no longer be used.
        Any database error from the inserts (after rollback).
    """
    with get_sync_conn() as conn:
        try:
            with conn.cursor() as cur:
                cur.execute(
                    "UPDATE invite_codes SET used_count = used_count + 1 "
                    "WHERE id = %s AND status = 'active' AND used_count < max_uses "
                    "RETURNING used_count, max_uses",
                    (invite_id,),
                )
                claimed = cur.fetchone()
                if not claimed:
                    raise HTTPException(status_code=400, detail="invite code exhausted")
                used_count, max_uses = claimed
                cur.execute(
                    "INSERT INTO users (email, password_hash, role, banned, created_at) VALUES (%s, %s, 'user', 0, %s) RETURNING id",
                    (email, pwd_hash, _utcnow().isoformat()),
                )
                user_id = cur.fetchone()[0]
                cur.execute(
                    "INSERT INTO subscriptions (user_id, tier, expires_at) VALUES (%s, 'free', NULL) ON CONFLICT(user_id) DO NOTHING",
                    (user_id,),
                )
                cur.execute(
                    "INSERT INTO invite_usage (code_id, user_id) VALUES (%s, %s)",
                    (invite_id, user_id),
                )
                if used_count >= max_uses:
                    cur.execute(
                        "UPDATE invite_codes SET status = 'exhausted' WHERE id = %s",
                        (invite_id,),
                    )
            conn.commit()
        except Exception:
            conn.rollback()
            raise
    return user_id


@router.post("/auth/register")
def register(body: RegisterReq):
    """Register a new user with a valid invite code and return a token pair.

    Raises:
        HTTPException(400): the invite is invalid, disabled, exhausted or
            expired, or the email is already registered.
        HTTPException(500): any other database failure.
    """
    ensure_tables()
    invite = _fetchone("SELECT * FROM invite_codes WHERE code = %s", (body.invite_code,))
    if not invite:
        raise HTTPException(status_code=400, detail="invalid invite code")
    if invite["status"] != "active":
        raise HTTPException(status_code=400, detail="invite code disabled")
    if invite["used_count"] >= invite["max_uses"]:
        raise HTTPException(status_code=400, detail="invite code exhausted")
    if invite["expires_at"] and invite["expires_at"] < _utcnow().isoformat():
        raise HTTPException(status_code=400, detail="invite code expired")

    pwd_hash = hash_password(body.password)
    try:
        user_id = _register_user(body.email.lower(), pwd_hash, invite["id"])
    except HTTPException:
        raise
    except Exception as e:
        msg = str(e).lower()
        if "unique" in msg or "duplicate" in msg:
            raise HTTPException(status_code=400, detail="email already registered") from e
        # Do not echo raw database error text to the client.
        raise HTTPException(status_code=500, detail="registration failed") from e

    user = _fetchone("SELECT * FROM users WHERE id = %s", (user_id,))
    return _token_response(user)


@router.post("/auth/login")
def login(body: LoginReq):
    """Authenticate with email/password and return a token pair (401 bad creds, 403 banned)."""
    ensure_tables()
    user = _fetchone("SELECT * FROM users WHERE email = %s", (body.email.lower(),))
    if not user or not verify_password(body.password, user["password_hash"]):
        raise HTTPException(status_code=401, detail="invalid credentials")
    if user["banned"]:
        raise HTTPException(status_code=403, detail="account banned")
    return _token_response(user)


@router.post("/auth/refresh")
def refresh_token(body: RefreshReq):
    """Exchange a refresh token for a new access token and a rotated refresh token.

    The old refresh token is revoked with a conditional UPDATE, so the same
    token cannot be redeemed twice by concurrent requests.
    """
    row = _fetchone(
        "SELECT * FROM refresh_tokens WHERE token = %s AND revoked = 0",
        (body.refresh_token,),
    )
    if not row:
        raise HTTPException(status_code=401, detail="invalid refresh token")
    if row["expires_at"] < _utcnow().isoformat():
        raise HTTPException(status_code=401, detail="refresh token expired")
    user = _fetchone("SELECT * FROM users WHERE id = %s", (row["user_id"],))
    if not user or user["banned"]:
        raise HTTPException(status_code=403, detail="account unavailable")
    claimed = _execute(
        "UPDATE refresh_tokens SET revoked = 1 WHERE id = %s AND revoked = 0 RETURNING id",
        (row["id"],),
    )
    if not claimed:
        # Another request revoked this token between our SELECT and UPDATE.
        raise HTTPException(status_code=401, detail="invalid refresh token")
    access = create_access_token(user["id"], user["email"], user["role"])
    new_refresh = create_refresh_token(user["id"])
    return {
        "access_token": access,
        "refresh_token": new_refresh,
        "token_type": "bearer",
        "expires_in": ACCESS_TOKEN_HOURS * 3600,
    }


@router.get("/auth/me")
def me(user: dict = Depends(get_current_user)):
    """Return the caller's profile plus subscription (defaults to free tier if no row exists)."""
    sub = _fetchone("SELECT tier, expires_at FROM subscriptions WHERE user_id = %s", (user["id"],))
    return {
        "id": user["id"],
        "email": user["email"],
        "role": user["role"],
        "discord_id": user.get("discord_id"),
        "created_at": user["created_at"],
        "subscription": dict(sub) if sub else {"tier": "free", "expires_at": None},
    }


# ─── Admin routes ────────────────────────────────────────────

@router.post("/admin/invite-codes")
def gen_invite_codes(body: GenInviteReq, admin: dict = Depends(require_admin)):
    """Generate ``body.count`` invite codes (8 upper-cased chars) in one transaction."""
    codes = []
    with get_sync_conn() as conn:
        with conn.cursor() as cur:
            for _ in range(body.count):
                code = secrets.token_urlsafe(6)[:8].upper()
                cur.execute(
                    "INSERT INTO invite_codes (code, created_by, max_uses) VALUES (%s, %s, %s)",
                    (code, admin["id"], body.max_uses),
                )
                codes.append(code)
        conn.commit()
    return {"codes": codes}


@router.get("/admin/invite-codes")
def list_invite_codes(admin: dict = Depends(require_admin)):
    """List all invite codes, newest first."""
    rows = _fetchall(
        "SELECT id, code, max_uses, used_count, status, expires_at, created_at FROM invite_codes ORDER BY id DESC"
    )
    return {"items": rows}


@router.delete("/admin/invite-codes/{code_id}")
def disable_invite_code(code_id: int, admin: dict = Depends(require_admin)):
    """Mark an invite code as disabled (reports ok even if the id does not exist)."""
    _execute("UPDATE invite_codes SET status = 'disabled' WHERE id = %s", (code_id,))
    return {"ok": True}


@router.get("/admin/users")
def list_users(admin: dict = Depends(require_admin)):
    """List all users (without password hashes), newest first."""
    rows = _fetchall(
        "SELECT id, email, role, banned, discord_id, created_at FROM users ORDER BY id DESC"
    )
    return {"items": rows}


@router.put("/admin/users/{user_id}/ban")
def ban_user(user_id: int, body: BanUserReq, admin: dict = Depends(require_admin)):
    """Set or clear the banned flag on a user."""
    _execute("UPDATE users SET banned = %s WHERE id = %s", (1 if body.banned else 0, user_id))
    return {"ok": True, "user_id": user_id, "banned": body.banned}