"""Authentication, invite-code and admin routes backed by a local SQLite file.

Access tokens are HS256-signed JWTs built by hand; refresh tokens are opaque
random strings stored in the ``refresh_tokens`` table and rotated on use.
"""

import base64
import hashlib
import hmac
import json
import os
import secrets
import sqlite3
import time
from contextlib import closing
from datetime import datetime, timedelta, timezone
from functools import wraps
from typing import Optional

from fastapi import APIRouter, HTTPException, Header, Depends, Request
from pydantic import BaseModel, EmailStr

DB_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "arb.db")
# SECURITY NOTE(review): when JWT_SECRET is not set in the environment this
# falls back to a secret that is committed in source, so anyone who can read
# the code can forge tokens (including admin ones). Set JWT_SECRET in every
# real deployment; the fallback is kept only so existing setups keep working.
JWT_SECRET = os.getenv("JWT_SECRET", "arb-engine-jwt-secret-v2-2026")
ACCESS_TOKEN_HOURS = 24
REFRESH_TOKEN_DAYS = 7

# How many fresh codes to try before giving up when an invite code collides.
_INVITE_CODE_ATTEMPTS = 5

# Columns added to ``users`` after the first schema version: (name, DDL).
_USER_COLUMN_MIGRATIONS = (
    ("role", "role TEXT NOT NULL DEFAULT 'user'"),
    ("banned", "banned INTEGER NOT NULL DEFAULT 0"),
)

router = APIRouter(prefix="/api", tags=["auth"])


# ─── Time helpers ─────────────────────────────────────────────

def _utcnow() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Naive values are kept on purpose: timestamps already stored in the
    database were written as naive-UTC ISO strings, and new rows must stay
    comparable with them.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


def _is_past(timestamp: str) -> bool:
    """Return True when the stored ISO timestamp lies before the current UTC time.

    The value is parsed instead of compared as text so that both the
    ``T``-separated form written by this module and the space-separated form
    produced by SQLite's ``datetime('now')`` are handled. If the value cannot
    be parsed, the original lexical comparison is used as a fallback.
    """
    try:
        parsed = datetime.fromisoformat(timestamp)
    except ValueError:
        return timestamp < _utcnow().isoformat()
    if parsed.tzinfo is not None:
        # Normalise aware values to naive UTC so the comparison is valid.
        parsed = parsed.astimezone(timezone.utc).replace(tzinfo=None)
    return parsed < _utcnow()


# ─── DB helpers ───────────────────────────────────────────────

def get_conn():
    """Open a new SQLite connection to ``DB_PATH`` whose rows are ``sqlite3.Row``."""
    conn = sqlite3.connect(DB_PATH)
    conn.row_factory = sqlite3.Row
    return conn


def ensure_tables():
    """Create every table this module uses (if absent) and add late columns to ``users``."""
    with closing(get_conn()) as conn:
        with conn:
            cur = conn.cursor()
            cur.execute("""
                CREATE TABLE IF NOT EXISTS users (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    email TEXT UNIQUE NOT NULL,
                    password_hash TEXT NOT NULL,
                    discord_id TEXT,
                    role TEXT NOT NULL DEFAULT 'user',
                    banned INTEGER NOT NULL DEFAULT 0,
                    created_at TEXT NOT NULL
                )
            """)
            cur.execute("""
                CREATE TABLE IF NOT EXISTS subscriptions (
                    user_id INTEGER PRIMARY KEY,
                    tier TEXT NOT NULL DEFAULT 'free',
                    expires_at TEXT,
                    FOREIGN KEY(user_id) REFERENCES users(id)
                )
            """)
            cur.execute("""
                CREATE TABLE IF NOT EXISTS signal_logs (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    symbol TEXT NOT NULL,
                    rate REAL NOT NULL,
                    annualized REAL NOT NULL,
                    sent_at TEXT NOT NULL,
                    message TEXT NOT NULL
                )
            """)
            cur.execute("""
                CREATE TABLE IF NOT EXISTS invite_codes (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    code TEXT UNIQUE NOT NULL,
                    created_by INTEGER,
                    max_uses INTEGER NOT NULL DEFAULT 1,
                    used_count INTEGER NOT NULL DEFAULT 0,
                    status TEXT NOT NULL DEFAULT 'active',
                    expires_at TEXT,
                    created_at TEXT DEFAULT (datetime('now'))
                )
            """)
            cur.execute("""
                CREATE TABLE IF NOT EXISTS invite_usage (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    code_id INTEGER NOT NULL,
                    user_id INTEGER NOT NULL,
                    used_at TEXT DEFAULT (datetime('now')),
                    FOREIGN KEY (code_id) REFERENCES invite_codes(id),
                    FOREIGN KEY (user_id) REFERENCES users(id)
                )
            """)
            cur.execute("""
                CREATE TABLE IF NOT EXISTS refresh_tokens (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    user_id INTEGER NOT NULL,
                    token TEXT UNIQUE NOT NULL,
                    expires_at TEXT NOT NULL,
                    revoked INTEGER NOT NULL DEFAULT 0,
                    created_at TEXT DEFAULT (datetime('now')),
                    FOREIGN KEY (user_id) REFERENCES users(id)
                )
            """)
            # migrate: add role/banned columns if missing. The existing
            # columns are inspected first so that a genuine ALTER failure is
            # raised instead of being silently swallowed.
            existing = {row["name"] for row in cur.execute("PRAGMA table_info(users)")}
            for column, ddl in _USER_COLUMN_MIGRATIONS:
                if column not in existing:
                    cur.execute(f"ALTER TABLE users ADD COLUMN {ddl}")


# ─── Password utils ──────────────────────────────────────────

def _scrypt_hex(password: str, salt: str) -> str:
    """Return the hex scrypt digest of *password* using *salt*'s text bytes as the salt."""
    return hashlib.scrypt(password.encode(), salt=salt.encode(), n=2**14, r=8, p=1).hex()


def hash_password(password: str) -> str:
    """Hash *password* with a fresh random salt and return ``"<salt>$<digest>"``."""
    salt = secrets.token_hex(16)
    return f"{salt}${_scrypt_hex(password, salt)}"


def verify_password(password: str, stored: str) -> bool:
    """Return True when *password* matches the ``"<salt>$<digest>"`` value in *stored*.

    Malformed or missing stored values yield False rather than raising.
    """
    try:
        salt, digest = stored.split("$", 1)
        candidate = _scrypt_hex(password, salt)
        # Constant-time comparison so timing does not leak digest prefixes.
        return hmac.compare_digest(candidate, digest)
    except (ValueError, TypeError, AttributeError):
        return False


# ─── JWT utils ───────────────────────────────────────────────

def b64url(data: bytes) -> str:
    """Encode *data* as URL-safe base64 text with the ``=`` padding removed."""
    return base64.urlsafe_b64encode(data).rstrip(b"=").decode()


def _sign(sign_input: bytes) -> str:
    """Return the b64url-encoded HMAC-SHA256 of *sign_input* keyed with ``JWT_SECRET``."""
    return b64url(hmac.new(JWT_SECRET.encode(), sign_input, hashlib.sha256).digest())


def create_access_token(user_id: int, email: str, role: str) -> str:
    """Build a signed HS256 access token valid for ``ACCESS_TOKEN_HOURS`` hours.

    The payload carries ``sub``, ``email``, ``role``, ``exp`` and
    ``type == "access"``.
    """
    header = b64url(json.dumps({"alg": "HS256", "typ": "JWT"}, separators=(",", ":")).encode())
    # time.time() is a true Unix epoch value. Calling .timestamp() on a naive
    # UTC datetime would interpret it as local time and shift ``exp`` by the
    # server's UTC offset.
    exp = int(time.time()) + ACCESS_TOKEN_HOURS * 3600
    payload = b64url(json.dumps({
        "sub": user_id,
        "email": email,
        "role": role,
        "exp": exp,
        "type": "access",
    }, separators=(",", ":")).encode())
    signature = _sign(f"{header}.{payload}".encode())
    return f"{header}.{payload}.{signature}"


def create_refresh_token(user_id: int) -> str:
    """Generate a random refresh token, store it for *user_id* and return it.

    The stored row expires ``REFRESH_TOKEN_DAYS`` days from now.
    """
    token = secrets.token_urlsafe(48)
    expires_at = (_utcnow() + timedelta(days=REFRESH_TOKEN_DAYS)).isoformat()
    with closing(get_conn()) as conn:
        with conn:
            conn.execute(
                "INSERT INTO refresh_tokens (user_id, token, expires_at) VALUES (?, ?, ?)",
                (user_id, token, expires_at),
            )
    return token


def parse_token(token: str) -> Optional[dict]:
    """Validate *token* and return its payload, or None if it is not acceptable.

    None is returned when the token is malformed, its signature does not
    match, its payload is not a JSON object, or its ``exp`` is in the past.
    """
    try:
        header_b64, payload_b64, sig_b64 = token.split(".")
        expected = _sign(f"{header_b64}.{payload_b64}".encode())
        if not hmac.compare_digest(expected, sig_b64):
            return None
        # Restore the padding that b64url() stripped before decoding.
        pad = "=" * (-len(payload_b64) % 4)
        payload = json.loads(base64.urlsafe_b64decode(payload_b64 + pad))
        if not isinstance(payload, dict):
            return None
        if int(payload.get("exp", 0)) < int(time.time()):
            return None
        return payload
    except (ValueError, TypeError, OverflowError):
        return None


# ─── Auth dependency ─────────────────────────────────────────

def get_current_user(authorization: Optional[str] = Header(default=None)) -> dict:
    """FastAPI dependency: return the authenticated user row as a dict.

    Raises:
        HTTPException: 401 when the bearer token is missing, invalid, expired,
            not an access token, or refers to an unknown user; 403 when the
            user is banned.
    """
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="missing token")
    token = authorization.split(" ", 1)[1].strip()
    payload = parse_token(token)
    if not payload:
        raise HTTPException(status_code=401, detail="invalid or expired token")
    if payload.get("type") != "access":
        raise HTTPException(status_code=401, detail="invalid token type")
    with closing(get_conn()) as conn:
        user = conn.execute(
            "SELECT * FROM users WHERE id = ?", (payload.get("sub"),)
        ).fetchone()
    if not user:
        raise HTTPException(status_code=401, detail="user not found")
    if user["banned"]:
        raise HTTPException(status_code=403, detail="account banned")
    return dict(user)


def require_admin(user: dict = Depends(get_current_user)) -> dict:
    """FastAPI dependency: return *user* if its role is ``admin``, else raise 403."""
    if user.get("role") != "admin":
        raise HTTPException(status_code=403, detail="admin required")
    return user


# ─── Request models ──────────────────────────────────────────

class RegisterReq(BaseModel):
    """Body of ``POST /auth/register``."""

    email: EmailStr
    password: str
    invite_code: str


class LoginReq(BaseModel):
    """Body of ``POST /auth/login``."""

    email: EmailStr
    password: str


class RefreshReq(BaseModel):
    """Body of ``POST /auth/refresh``."""

    refresh_token: str


class GenInviteReq(BaseModel):
    """Body of ``POST /admin/invite-codes``."""

    count: int = 1
    max_uses: int = 1


class BanUserReq(BaseModel):
    """Body of ``PUT /admin/users/{user_id}/ban``."""

    banned: bool = True


# ─── Auth routes ─────────────────────────────────────────────

@router.post("/auth/register")
def register(body: RegisterReq):
    """Create a user from a valid invite code and return a fresh token pair.

    Raises:
        HTTPException: 400 when the invite code is unknown, not active,
            exhausted or expired, or when the email is already registered.
    """
    ensure_tables()
    with closing(get_conn()) as conn:
        # verify invite code
        invite = conn.execute(
            "SELECT * FROM invite_codes WHERE code = ?", (body.invite_code,)
        ).fetchone()
        if not invite:
            raise HTTPException(status_code=400, detail="invalid invite code")
        if invite["status"] != "active":
            raise HTTPException(status_code=400, detail="invite code disabled")
        if invite["used_count"] >= invite["max_uses"]:
            raise HTTPException(status_code=400, detail="invite code exhausted")
        # NOTE(review): nothing visible in this file writes invite_codes.expires_at;
        # _is_past accepts both 'T'- and space-separated values - confirm the writer's format.
        if invite["expires_at"] and _is_past(invite["expires_at"]):
            raise HTTPException(status_code=400, detail="invite code expired")

        # Hash before taking the write lock: scrypt is deliberately slow.
        pwd_hash = hash_password(body.password)
        try:
            with conn:  # commits on success, rolls back on any exception
                # Claim one use atomically. The checks above ran outside the
                # transaction, so a concurrent registration may have consumed
                # or disabled the code since; the WHERE clause re-checks it
                # and the CASE auto-exhausts the code when the limit is hit.
                claimed = conn.execute(
                    "UPDATE invite_codes SET used_count = used_count + 1, "
                    "status = CASE WHEN used_count + 1 >= max_uses "
                    "THEN 'exhausted' ELSE status END "
                    "WHERE id = ? AND status = 'active' AND used_count < max_uses",
                    (invite["id"],),
                )
                if claimed.rowcount != 1:
                    raise HTTPException(status_code=400, detail="invite code exhausted")
                cur = conn.execute(
                    "INSERT INTO users (email, password_hash, role, banned, created_at) "
                    "VALUES (?, ?, 'user', 0, ?)",
                    (body.email.lower(), pwd_hash, _utcnow().isoformat()),
                )
                user_id = cur.lastrowid
                conn.execute(
                    "INSERT OR REPLACE INTO subscriptions (user_id, tier, expires_at) "
                    "VALUES (?, 'free', NULL)",
                    (user_id,),
                )
                # record invite usage
                conn.execute(
                    "INSERT INTO invite_usage (code_id, user_id) VALUES (?, ?)",
                    (invite["id"], user_id),
                )
        except sqlite3.IntegrityError:
            # The transaction (including the invite claim) was rolled back above.
            raise HTTPException(status_code=400, detail="email already registered") from None

        # issue tokens
        user = conn.execute("SELECT * FROM users WHERE id = ?", (user_id,)).fetchone()

    access = create_access_token(user["id"], user["email"], user["role"])
    refresh = create_refresh_token(user["id"])
    return {
        "access_token": access,
        "refresh_token": refresh,
        "token_type": "bearer",
        "expires_in": ACCESS_TOKEN_HOURS * 3600,
        "user": {"id": user["id"], "email": user["email"], "role": user["role"]},
    }


@router.post("/auth/login")
def login(body: LoginReq):
    """Check email/password and return a fresh token pair.

    Raises:
        HTTPException: 401 on unknown email or wrong password; 403 when the
            account is banned.
    """
    ensure_tables()
    with closing(get_conn()) as conn:
        user = conn.execute(
            "SELECT * FROM users WHERE email = ?", (body.email.lower(),)
        ).fetchone()
    if not user or not verify_password(body.password, user["password_hash"]):
        raise HTTPException(status_code=401, detail="invalid credentials")
    if user["banned"]:
        raise HTTPException(status_code=403, detail="account banned")
    access = create_access_token(user["id"], user["email"], user["role"])
    refresh = create_refresh_token(user["id"])
    return {
        "access_token": access,
        "refresh_token": refresh,
        "token_type": "bearer",
        "expires_in": ACCESS_TOKEN_HOURS * 3600,
        "user": {"id": user["id"], "email": user["email"], "role": user["role"]},
    }


@router.post("/auth/refresh")
def refresh_token(body: RefreshReq):
    """Rotate a refresh token: revoke the presented one and issue a new pair.

    Raises:
        HTTPException: 401 when the token is unknown, revoked or expired;
            403 when its user no longer exists or is banned.
    """
    with closing(get_conn()) as conn:
        row = conn.execute(
            "SELECT * FROM refresh_tokens WHERE token = ? AND revoked = 0",
            (body.refresh_token,),
        ).fetchone()
        if not row:
            raise HTTPException(status_code=401, detail="invalid refresh token")
        if _is_past(row["expires_at"]):
            raise HTTPException(status_code=401, detail="refresh token expired")
        user = conn.execute(
            "SELECT * FROM users WHERE id = ?", (row["user_id"],)
        ).fetchone()
        if not user or user["banned"]:
            raise HTTPException(status_code=403, detail="account unavailable")
        # revoke old, issue new. The ``revoked = 0`` guard makes the rotation
        # single-use even when two requests present the same token at once.
        with conn:
            revoked = conn.execute(
                "UPDATE refresh_tokens SET revoked = 1 WHERE id = ? AND revoked = 0",
                (row["id"],),
            )
        if revoked.rowcount != 1:
            raise HTTPException(status_code=401, detail="invalid refresh token")

    access = create_access_token(user["id"], user["email"], user["role"])
    new_refresh = create_refresh_token(user["id"])
    return {
        "access_token": access,
        "refresh_token": new_refresh,
        "token_type": "bearer",
        "expires_in": ACCESS_TOKEN_HOURS * 3600,
    }


@router.get("/auth/me")
def me(user: dict = Depends(get_current_user)):
    """Return the current user's profile together with their subscription row.

    When no subscription row exists a free tier with no expiry is reported.
    """
    with closing(get_conn()) as conn:
        sub = conn.execute(
            "SELECT tier, expires_at FROM subscriptions WHERE user_id = ?",
            (user["id"],),
        ).fetchone()
    return {
        "id": user["id"],
        "email": user["email"],
        "role": user["role"],
        "discord_id": user.get("discord_id"),
        "created_at": user["created_at"],
        "subscription": dict(sub) if sub else {"tier": "free", "expires_at": None},
    }


# ─── Admin routes ────────────────────────────────────────────

def _insert_invite_code(conn, created_by, max_uses: int) -> str:
    """Insert one new invite code row and return the code, retrying on collisions."""
    for _ in range(_INVITE_CODE_ATTEMPTS):
        code = secrets.token_urlsafe(6)[:8].upper()
        try:
            conn.execute(
                "INSERT INTO invite_codes (code, created_by, max_uses) VALUES (?, ?, ?)",
                (code, created_by, max_uses),
            )
        except sqlite3.IntegrityError:
            # ``code`` is UNIQUE; an already-used value was drawn, so try again.
            continue
        return code
    raise HTTPException(status_code=500, detail="could not generate invite code")


@router.post("/admin/invite-codes")
def gen_invite_codes(body: GenInviteReq, admin: dict = Depends(require_admin)):
    """Create ``body.count`` invite codes owned by the calling admin and return them."""
    with closing(get_conn()) as conn:
        with conn:  # all codes are committed together or not at all
            codes = [
                _insert_invite_code(conn, admin["id"], body.max_uses)
                for _ in range(body.count)
            ]
    return {"codes": codes}


@router.get("/admin/invite-codes")
def list_invite_codes(admin: dict = Depends(require_admin)):
    """List all invite codes, newest first."""
    with closing(get_conn()) as conn:
        rows = conn.execute(
            "SELECT id, code, max_uses, used_count, status, expires_at, created_at "
            "FROM invite_codes ORDER BY id DESC"
        ).fetchall()
    return {"items": [dict(r) for r in rows]}


@router.delete("/admin/invite-codes/{code_id}")
def disable_invite_code(code_id: int, admin: dict = Depends(require_admin)):
    """Mark the invite code with id *code_id* as disabled (the row is kept)."""
    with closing(get_conn()) as conn:
        with conn:
            conn.execute(
                "UPDATE invite_codes SET status = 'disabled' WHERE id = ?", (code_id,)
            )
    return {"ok": True}


@router.get("/admin/users")
def list_users(admin: dict = Depends(require_admin)):
    """List all users, newest first, without their password hashes."""
    with closing(get_conn()) as conn:
        rows = conn.execute(
            "SELECT id, email, role, banned, discord_id, created_at "
            "FROM users ORDER BY id DESC"
        ).fetchall()
    return {"items": [dict(r) for r in rows]}


@router.put("/admin/users/{user_id}/ban")
def ban_user(user_id: int, body: BanUserReq, admin: dict = Depends(require_admin)):
    """Set or clear the ``banned`` flag of user *user_id* according to ``body.banned``."""
    with closing(get_conn()) as conn:
        with conn:
            conn.execute(
                "UPDATE users SET banned = ? WHERE id = ?",
                (1 if body.banned else 0, user_id),
            )
    return {"ok": True, "user_id": user_id, "banned": body.banned}