arbitrage-engine/backend/auth.py

197 lines
6.3 KiB
Python

import base64
import hashlib
import hmac
import json
import os
import secrets
import sqlite3
import time
from datetime import datetime, timedelta
from typing import Optional

from fastapi import APIRouter, HTTPException, Header
from pydantic import BaseModel, EmailStr

# SQLite database file "arb.db", located one directory above the directory
# that contains this module.
DB_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "arb.db")
# HMAC key used to sign and verify tokens, read from the JWT_SECRET
# environment variable.
# NOTE(review): the fallback literal is visible in source control; anyone who
# knows it can forge tokens, so deployments should always set JWT_SECRET.
JWT_SECRET = os.getenv("JWT_SECRET", "arb-secret-change-me")
# Token lifetime in hours (7 days).
JWT_EXPIRE_HOURS = 24 * 7
# Router that collects the endpoints declared in this module under /api.
router = APIRouter(prefix="/api", tags=["auth"])
def get_conn():
    """Open a new connection to the SQLite database at ``DB_PATH``.

    Rows fetched through the returned connection are ``sqlite3.Row``
    objects, so columns can be read by name as well as by position.
    The caller is responsible for closing the connection.
    """
    connection = sqlite3.connect(DB_PATH)
    connection.row_factory = sqlite3.Row
    return connection
def ensure_tables():
    """Create the ``users``, ``subscriptions`` and ``signal_logs`` tables.

    Every statement is ``CREATE TABLE IF NOT EXISTS``, so calling this
    repeatedly leaves existing tables untouched. The connection is closed
    in a ``finally`` clause so a failing statement does not leak it.
    """
    schema = (
        """
        CREATE TABLE IF NOT EXISTS users (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            email TEXT UNIQUE NOT NULL,
            password_hash TEXT NOT NULL,
            discord_id TEXT,
            created_at TEXT NOT NULL
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS subscriptions (
            user_id INTEGER PRIMARY KEY,
            tier TEXT NOT NULL DEFAULT 'free',
            expires_at TEXT,
            FOREIGN KEY(user_id) REFERENCES users(id)
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS signal_logs (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            symbol TEXT NOT NULL,
            rate REAL NOT NULL,
            annualized REAL NOT NULL,
            sent_at TEXT NOT NULL,
            message TEXT NOT NULL
        )
        """,
    )
    conn = get_conn()
    try:
        cur = conn.cursor()
        for ddl in schema:
            cur.execute(ddl)
        conn.commit()
    finally:
        conn.close()
def hash_password(password: str) -> str:
    """Derive a salted scrypt hash of ``password``.

    Returns:
        A string of the form ``"<salt>$<digest>"``. ``salt`` is 32 random
        hex characters; its ASCII text (not the decoded bytes) is what is
        fed to scrypt. ``digest`` is the hex-encoded scrypt output computed
        with n=2**14, r=8, p=1.
    """
    salt_hex = secrets.token_hex(16)
    derived = hashlib.scrypt(
        password.encode(),
        salt=salt_hex.encode(),
        n=2**14,
        r=8,
        p=1,
    )
    return "$".join((salt_hex, derived.hex()))
def verify_password(password: str, stored: str) -> bool:
    """Check ``password`` against a stored ``"<salt>$<digest>"`` string.

    The digest is recomputed with scrypt (n=2**14, r=8, p=1) using the
    stored salt text and compared via ``hmac.compare_digest``.

    Returns:
        ``True`` on a match. ``False`` on a mismatch or on any error, such
        as a ``stored`` value that has no ``$`` separator or is not a string.
    """
    try:
        salt_hex, separator, expected_hex = stored.partition("$")
        if not separator:
            # No "$" present, so the value is not in salt$digest form.
            return False
        recomputed = hashlib.scrypt(
            password.encode(),
            salt=salt_hex.encode(),
            n=2**14,
            r=8,
            p=1,
        )
        return hmac.compare_digest(recomputed.hex(), expected_hex)
    except Exception:
        # Any failure while parsing or hashing counts as "does not match".
        return False
def b64url(data: bytes) -> str:
    """Return ``data`` as URL-safe base64 text with trailing ``=`` padding removed."""
    encoded = base64.urlsafe_b64encode(data).decode()
    return encoded.rstrip("=")
def create_token(user_id: int, email: str) -> str:
    """Build a signed HS256 JWT for a user.

    Args:
        user_id: Stored in the ``sub`` claim.
        email: Stored in the ``email`` claim.

    Returns:
        ``"<header>.<payload>.<signature>"`` where each part is unpadded
        URL-safe base64 and the signature is HMAC-SHA256 over
        ``"<header>.<payload>"`` keyed with ``JWT_SECRET``. The ``exp``
        claim is set ``JWT_EXPIRE_HOURS`` hours from now, in seconds since
        the Unix epoch.
    """
    header = b64url(json.dumps({"alg": "HS256", "typ": "JWT"}, separators=(",", ":")).encode())
    # Use time.time() for the epoch value: calling .timestamp() on the naive
    # result of datetime.utcnow() interprets it as *local* time, which shifts
    # "exp" by the host's UTC offset.
    exp = int(time.time()) + JWT_EXPIRE_HOURS * 3600
    payload = b64url(json.dumps({"sub": user_id, "email": email, "exp": exp}, separators=(",", ":")).encode())
    sign_input = f"{header}.{payload}".encode()
    signature = hmac.new(JWT_SECRET.encode(), sign_input, hashlib.sha256).digest()
    return f"{header}.{payload}.{b64url(signature)}"


def parse_token(token: str) -> Optional[dict]:
    """Verify a token produced by ``create_token`` and return its payload.

    Returns:
        The decoded payload dict when the signature matches and ``exp`` has
        not passed; ``None`` for a bad signature, an expired token, or any
        malformed input.
    """
    try:
        header_b64, payload_b64, sig_b64 = token.split(".")
        sign_input = f"{header_b64}.{payload_b64}".encode()
        expected = b64url(hmac.new(JWT_SECRET.encode(), sign_input, hashlib.sha256).digest())
        if not hmac.compare_digest(expected, sig_b64):
            return None
        # Restore the "=" padding that b64url() strips before decoding.
        pad = "=" * (-len(payload_b64) % 4)
        payload = json.loads(base64.urlsafe_b64decode(payload_b64 + pad))
        # Compare against the true epoch time, matching create_token; a
        # payload without "exp" is treated as already expired.
        if int(payload.get("exp", 0)) < int(time.time()):
            return None
        return payload
    except Exception:
        # Any parsing/decoding failure means the token is not acceptable;
        # callers only distinguish "valid payload" from None.
        return None
def get_user_from_auth(authorization: Optional[str]) -> sqlite3.Row:
    """Resolve an ``Authorization`` header value to a row of ``users``.

    Args:
        authorization: Raw header value, expected as ``"Bearer <token>"``.

    Returns:
        The matching ``users`` row (all columns).

    Raises:
        HTTPException: 401 with detail ``"missing token"`` when the header
            is absent or lacks the ``Bearer `` prefix, ``"invalid token"``
            when the token fails verification or has no ``sub`` claim, and
            ``"user not found"`` when no user has that id.
    """
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="missing token")
    token = authorization.split(" ", 1)[1].strip()
    payload = parse_token(token)
    # A token without "sub" is rejected here instead of surfacing as a
    # KeyError (HTTP 500) at the query below.
    user_id = payload.get("sub") if payload else None
    if user_id is None:
        raise HTTPException(status_code=401, detail="invalid token")
    conn = get_conn()
    try:
        user = conn.execute("SELECT * FROM users WHERE id = ?", (user_id,)).fetchone()
    finally:
        # Close even if the query raises so the connection is not leaked.
        conn.close()
    if not user:
        raise HTTPException(status_code=401, detail="user not found")
    return user
class RegisterReq(BaseModel):
    # Request body for POST /api/auth/register.
    email: EmailStr  # must parse as an e-mail address (pydantic EmailStr)
    password: str  # plain-text password; register() hashes it before storing
class LoginReq(BaseModel):
    # Request body for POST /api/auth/login.
    email: EmailStr  # must parse as an e-mail address (pydantic EmailStr)
    password: str  # plain-text password checked against the stored hash
class BindDiscordReq(BaseModel):
    # Request body for POST /api/user/bind-discord.
    discord_id: str  # written as-is to users.discord_id for the caller
@router.post("/auth/register")
def register(body: RegisterReq):
    """Create a user with a free-tier subscription row.

    The e-mail is stored lower-cased and the password is stored only as a
    salted hash.

    Returns:
        A dict with the new user's ``id``, ``email``, ``discord_id`` and
        ``created_at``.

    Raises:
        HTTPException: 400 ``"email already registered"`` when the insert
            violates a database integrity constraint (the unique e-mail).
    """
    ensure_tables()
    pwd_hash = hash_password(body.password)
    conn = get_conn()
    try:
        try:
            cur = conn.execute(
                "INSERT INTO users (email, password_hash, created_at) VALUES (?, ?, ?)",
                (body.email.lower(), pwd_hash, datetime.utcnow().isoformat()),
            )
            user_id = cur.lastrowid
            conn.execute(
                "INSERT OR REPLACE INTO subscriptions (user_id, tier, expires_at) VALUES (?, 'free', NULL)",
                (user_id,),
            )
            conn.commit()
        except sqlite3.IntegrityError as exc:
            raise HTTPException(status_code=400, detail="email already registered") from exc
        user = conn.execute(
            "SELECT id, email, discord_id, created_at FROM users WHERE id = ?",
            (user_id,),
        ).fetchone()
        return dict(user)
    finally:
        # Runs on success, on the 400 path, and on unexpected errors alike,
        # so the connection is never leaked. Closing without commit discards
        # any partial insert.
        conn.close()
@router.post("/auth/login")
def login(body: LoginReq):
    """Authenticate by e-mail and password and issue a bearer token.

    The e-mail is lower-cased before lookup, matching how register() stores it.

    Returns:
        ``{"token": <jwt>, "token_type": "bearer"}``.

    Raises:
        HTTPException: 401 ``"invalid credentials"`` when the e-mail is
            unknown or the password does not match. The same message is
            used for both cases.
    """
    ensure_tables()
    conn = get_conn()
    try:
        user = conn.execute("SELECT * FROM users WHERE email = ?", (body.email.lower(),)).fetchone()
    finally:
        # Close even if the query raises so the connection is not leaked.
        conn.close()
    if not user or not verify_password(body.password, user["password_hash"]):
        raise HTTPException(status_code=401, detail="invalid credentials")
    token = create_token(user["id"], user["email"])
    return {"token": token, "token_type": "bearer"}
@router.post("/user/bind-discord")
def bind_discord(body: BindDiscordReq, authorization: Optional[str] = Header(default=None)):
    """Store ``body.discord_id`` on the authenticated user's row.

    Returns:
        A dict with the user's ``id``, ``email``, updated ``discord_id``
        and ``created_at``.

    Raises:
        HTTPException: 401 from ``get_user_from_auth`` when the
            ``Authorization`` header is missing or invalid.
    """
    user = get_user_from_auth(authorization)
    conn = get_conn()
    try:
        conn.execute("UPDATE users SET discord_id = ? WHERE id = ?", (body.discord_id, user["id"]))
        conn.commit()
        updated = conn.execute(
            "SELECT id, email, discord_id, created_at FROM users WHERE id = ?",
            (user["id"],),
        ).fetchone()
    finally:
        # Close even if the update or select raises so the connection is
        # not leaked.
        conn.close()
    return dict(updated)
@router.get("/user/me")
def me(authorization: Optional[str] = Header(default=None)):
    """Return the authenticated user's profile and subscription.

    Returns:
        A dict with ``id``, ``email``, ``discord_id``, ``created_at`` and a
        ``subscription`` dict (``tier``, ``expires_at``). When the user has
        no row in ``subscriptions`` the subscription defaults to
        ``{"tier": "free", "expires_at": None}``.

    Raises:
        HTTPException: 401 from ``get_user_from_auth`` when the
            ``Authorization`` header is missing or invalid.
    """
    user = get_user_from_auth(authorization)
    conn = get_conn()
    try:
        sub = conn.execute(
            "SELECT tier, expires_at FROM subscriptions WHERE user_id = ?",
            (user["id"],),
        ).fetchone()
    finally:
        # Close even if the query raises so the connection is not leaked.
        conn.close()
    out = {"id": user["id"], "email": user["email"], "discord_id": user["discord_id"], "created_at": user["created_at"]}
    out["subscription"] = dict(sub) if sub else {"tier": "free", "expires_at": None}
    return out