"""Authentication and user-account API routes backed by a local SQLite file.

Provides:

* password hashing / verification with scrypt,
* a minimal hand-rolled HS256 JWT (create / parse),
* FastAPI routes for register, login, binding a Discord id and fetching
  the current user together with their subscription.
"""
import base64
import hashlib
import hmac
import json
import os
import secrets
import sqlite3
from contextlib import closing
from datetime import datetime, timedelta, timezone
from typing import Optional

from fastapi import APIRouter, HTTPException, Header
from pydantic import BaseModel, EmailStr

# The database file lives one directory above the directory of this module.
DB_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "arb.db")
# SECURITY: the fallback below is a publicly known string. Anyone who knows it
# can forge tokens, so JWT_SECRET must be set in the environment for any real
# deployment.
JWT_SECRET = os.getenv("JWT_SECRET", "arb-secret-change-me")
# Token lifetime in hours (7 days).
JWT_EXPIRE_HOURS = 24 * 7

router = APIRouter(prefix="/api", tags=["auth"])


def get_conn():
    """Open a new SQLite connection to ``DB_PATH`` with name-addressable rows.

    The caller owns the connection and is responsible for closing it.
    """
    conn = sqlite3.connect(DB_PATH)
    conn.row_factory = sqlite3.Row
    return conn


def ensure_tables():
    """Create the ``users``, ``subscriptions`` and ``signal_logs`` tables if missing."""
    # ``closing`` guarantees the connection is released even if a statement
    # fails; a bare ``with conn:`` only manages the transaction, not the handle.
    with closing(get_conn()) as conn:
        cur = conn.cursor()
        cur.execute(
            """
            CREATE TABLE IF NOT EXISTS users (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                email TEXT UNIQUE NOT NULL,
                password_hash TEXT NOT NULL,
                discord_id TEXT,
                created_at TEXT NOT NULL
            )
            """
        )
        cur.execute(
            """
            CREATE TABLE IF NOT EXISTS subscriptions (
                user_id INTEGER PRIMARY KEY,
                tier TEXT NOT NULL DEFAULT 'free',
                expires_at TEXT,
                FOREIGN KEY(user_id) REFERENCES users(id)
            )
            """
        )
        cur.execute(
            """
            CREATE TABLE IF NOT EXISTS signal_logs (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                symbol TEXT NOT NULL,
                rate REAL NOT NULL,
                annualized REAL NOT NULL,
                sent_at TEXT NOT NULL,
                message TEXT NOT NULL
            )
            """
        )
        conn.commit()


def _utc_now() -> datetime:
    """Return the current time as a timezone-aware UTC datetime."""
    return datetime.now(timezone.utc)


def _utc_now_iso() -> str:
    """Return the current UTC time as a naive ISO-8601 string.

    The tzinfo is stripped so the text has no ``+00:00`` suffix, which keeps
    the format identical to ``created_at`` values already stored.
    """
    return _utc_now().replace(tzinfo=None).isoformat()


def _scrypt_hex(password: str, salt: str) -> str:
    """Return the hex scrypt digest of ``password`` under ``salt`` (n=2**14, r=8, p=1)."""
    return hashlib.scrypt(
        password.encode(), salt=salt.encode(), n=2**14, r=8, p=1
    ).hex()


def hash_password(password: str) -> str:
    """Hash ``password`` with a fresh random salt.

    Returns:
        A string of the form ``"<salt>$<hex digest>"`` where the salt is
        32 hex characters.
    """
    salt = secrets.token_hex(16)
    return f"{salt}${_scrypt_hex(password, salt)}"


def verify_password(password: str, stored: str) -> bool:
    """Check ``password`` against a value produced by :func:`hash_password`.

    Returns ``False`` (never raises) when ``stored`` is malformed or the
    digest cannot be computed.
    """
    try:
        salt, digest = stored.split("$", 1)
        candidate = _scrypt_hex(password, salt)
        # Constant-time comparison so the match position is not leaked.
        return hmac.compare_digest(candidate, digest)
    except (ValueError, TypeError, AttributeError):
        # ValueError: no "$" separator, un-encodable text, scrypt rejection.
        # TypeError: non-ASCII text passed to compare_digest.
        # AttributeError: ``stored`` is not a string (e.g. None).
        return False


def b64url(data: bytes) -> str:
    """Encode ``data`` as URL-safe base64 with the ``=`` padding stripped."""
    return base64.urlsafe_b64encode(data).rstrip(b"=").decode()


def _sign(sign_input: bytes) -> str:
    """Return the b64url HMAC-SHA256 signature of ``sign_input`` under ``JWT_SECRET``."""
    return b64url(hmac.new(JWT_SECRET.encode(), sign_input, hashlib.sha256).digest())


def create_token(user_id: int, email: str) -> str:
    """Build a signed HS256 JWT carrying ``sub``, ``email`` and ``exp`` claims.

    ``exp`` is a real Unix timestamp ``JWT_EXPIRE_HOURS`` in the future.
    """
    header = b64url(
        json.dumps({"alg": "HS256", "typ": "JWT"}, separators=(",", ":")).encode()
    )
    # An aware datetime is required here: calling .timestamp() on a naive
    # utcnow() value interprets it as local time and skews ``exp`` by the
    # server's UTC offset.
    exp = int((_utc_now() + timedelta(hours=JWT_EXPIRE_HOURS)).timestamp())
    payload = b64url(
        json.dumps(
            {"sub": user_id, "email": email, "exp": exp}, separators=(",", ":")
        ).encode()
    )
    signature = _sign(f"{header}.{payload}".encode())
    return f"{header}.{payload}.{signature}"


def parse_token(token: str) -> Optional[dict]:
    """Validate ``token`` and return its payload.

    Returns:
        The decoded payload dict, or ``None`` when the token is malformed,
        the signature does not match, or the ``exp`` claim is in the past
        (a missing ``exp`` counts as expired).
    """
    try:
        header_b64, payload_b64, sig_b64 = token.split(".")
        expected = _sign(f"{header_b64}.{payload_b64}".encode())
        if not hmac.compare_digest(expected, sig_b64):
            return None
        # Restore the padding that b64url() stripped before decoding.
        pad = "=" * (-len(payload_b64) % 4)
        payload = json.loads(base64.urlsafe_b64decode(payload_b64 + pad))
        if not isinstance(payload, dict):
            return None
        if int(payload.get("exp", 0)) < int(_utc_now().timestamp()):
            return None
        return payload
    except (ValueError, TypeError, AttributeError, OverflowError):
        # ValueError covers wrong segment count, bad base64, bad JSON/UTF-8
        # and a non-numeric ``exp``; TypeError covers non-ASCII signatures.
        return None


def get_user_from_auth(authorization: Optional[str]) -> sqlite3.Row:
    """Resolve an ``Authorization: Bearer <token>`` header to a ``users`` row.

    Raises:
        HTTPException: 401 when the header is missing or not a Bearer
            header, the token is invalid or expired, or the user no longer
            exists.
    """
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="missing token")
    token = authorization.split(" ", 1)[1].strip()
    payload = parse_token(token)
    # A correctly signed payload without ``sub`` is still unusable.
    if not payload or payload.get("sub") is None:
        raise HTTPException(status_code=401, detail="invalid token")
    # Match register/login so authenticated routes do not fail with
    # "no such table" on a fresh database.
    ensure_tables()
    with closing(get_conn()) as conn:
        user = conn.execute(
            "SELECT * FROM users WHERE id = ?", (payload["sub"],)
        ).fetchone()
    if not user:
        raise HTTPException(status_code=401, detail="user not found")
    return user


class RegisterReq(BaseModel):
    """Request body for ``POST /auth/register``."""

    email: EmailStr
    password: str


class LoginReq(BaseModel):
    """Request body for ``POST /auth/login``."""

    email: EmailStr
    password: str


class BindDiscordReq(BaseModel):
    """Request body for ``POST /user/bind-discord``."""

    discord_id: str


@router.post("/auth/register")
def register(body: RegisterReq):
    """Create a user with a free subscription and return the public fields.

    The email is stored lower-cased.

    Raises:
        HTTPException: 400 when the email is already registered.
    """
    ensure_tables()
    pwd_hash = hash_password(body.password)
    with closing(get_conn()) as conn:
        try:
            cur = conn.execute(
                "INSERT INTO users (email, password_hash, created_at) VALUES (?, ?, ?)",
                (body.email.lower(), pwd_hash, _utc_now_iso()),
            )
            user_id = cur.lastrowid
            conn.execute(
                "INSERT OR REPLACE INTO subscriptions (user_id, tier, expires_at) "
                "VALUES (?, 'free', NULL)",
                (user_id,),
            )
            conn.commit()
        except sqlite3.IntegrityError:
            # Closing the connection discards the uncommitted insert.
            raise HTTPException(
                status_code=400, detail="email already registered"
            ) from None
        user = conn.execute(
            "SELECT id, email, discord_id, created_at FROM users WHERE id = ?",
            (user_id,),
        ).fetchone()
    return dict(user)


@router.post("/auth/login")
def login(body: LoginReq):
    """Verify credentials and return a bearer token.

    Raises:
        HTTPException: 401 when the email is unknown or the password is wrong.
    """
    ensure_tables()
    with closing(get_conn()) as conn:
        user = conn.execute(
            "SELECT * FROM users WHERE email = ?", (body.email.lower(),)
        ).fetchone()
    if not user or not verify_password(body.password, user["password_hash"]):
        raise HTTPException(status_code=401, detail="invalid credentials")
    token = create_token(user["id"], user["email"])
    return {"token": token, "token_type": "bearer"}


@router.post("/user/bind-discord")
def bind_discord(body: BindDiscordReq, authorization: Optional[str] = Header(default=None)):
    """Set the caller's ``discord_id`` and return their updated public fields."""
    user = get_user_from_auth(authorization)
    with closing(get_conn()) as conn:
        conn.execute(
            "UPDATE users SET discord_id = ? WHERE id = ?",
            (body.discord_id, user["id"]),
        )
        conn.commit()
        updated = conn.execute(
            "SELECT id, email, discord_id, created_at FROM users WHERE id = ?",
            (user["id"],),
        ).fetchone()
    return dict(updated)


@router.get("/user/me")
def me(authorization: Optional[str] = Header(default=None)):
    """Return the caller's profile and subscription.

    Falls back to a free tier with no expiry when no subscription row exists.
    """
    user = get_user_from_auth(authorization)
    with closing(get_conn()) as conn:
        sub = conn.execute(
            "SELECT tier, expires_at FROM subscriptions WHERE user_id = ?",
            (user["id"],),
        ).fetchone()
    out = {
        "id": user["id"],
        "email": user["email"],
        "discord_id": user["discord_id"],
        "created_at": user["created_at"],
    }
    out["subscription"] = dict(sub) if sub else {"tier": "free", "expires_at": None}
    return out