diff --git a/backend/auth.py b/backend/auth.py new file mode 100644 index 0000000..b405b48 --- /dev/null +++ b/backend/auth.py @@ -0,0 +1,196 @@ +import os +import sqlite3 +import hashlib +import secrets +import hmac +import base64 +import json +from datetime import datetime, timedelta +from typing import Optional + +from fastapi import APIRouter, HTTPException, Header +from pydantic import BaseModel, EmailStr + +DB_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "arb.db") +JWT_SECRET = os.getenv("JWT_SECRET", "arb-secret-change-me") +JWT_EXPIRE_HOURS = 24 * 7 + +router = APIRouter(prefix="/api", tags=["auth"]) + + +def get_conn(): + conn = sqlite3.connect(DB_PATH) + conn.row_factory = sqlite3.Row + return conn + + +def ensure_tables(): + conn = get_conn() + cur = conn.cursor() + cur.execute( + """ + CREATE TABLE IF NOT EXISTS users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + email TEXT UNIQUE NOT NULL, + password_hash TEXT NOT NULL, + discord_id TEXT, + created_at TEXT NOT NULL + ) + """ + ) + cur.execute( + """ + CREATE TABLE IF NOT EXISTS subscriptions ( + user_id INTEGER PRIMARY KEY, + tier TEXT NOT NULL DEFAULT 'free', + expires_at TEXT, + FOREIGN KEY(user_id) REFERENCES users(id) + ) + """ + ) + cur.execute( + """ + CREATE TABLE IF NOT EXISTS signal_logs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + symbol TEXT NOT NULL, + rate REAL NOT NULL, + annualized REAL NOT NULL, + sent_at TEXT NOT NULL, + message TEXT NOT NULL + ) + """ + ) + conn.commit() + conn.close() + + +def hash_password(password: str) -> str: + salt = secrets.token_hex(16) + digest = hashlib.scrypt(password.encode(), salt=salt.encode(), n=2**14, r=8, p=1).hex() + return f"{salt}${digest}" + + +def verify_password(password: str, stored: str) -> bool: + try: + salt, digest = stored.split("$", 1) + candidate = hashlib.scrypt(password.encode(), salt=salt.encode(), n=2**14, r=8, p=1).hex() + return hmac.compare_digest(candidate, digest) + except Exception: + return False + + +def b64url(data: bytes) -> str: + return base64.urlsafe_b64encode(data).rstrip(b"=").decode() + + +def create_token(user_id: int, email: str) -> str: + header = b64url(json.dumps({"alg": "HS256", "typ": "JWT"}, separators=(",", ":")).encode()) + exp = int((datetime.utcnow() + timedelta(hours=JWT_EXPIRE_HOURS)).timestamp()) + payload = b64url(json.dumps({"sub": user_id, "email": email, "exp": exp}, separators=(",", ":")).encode()) + sign_input = f"{header}.{payload}".encode() + signature = hmac.new(JWT_SECRET.encode(), sign_input, hashlib.sha256).digest() + return f"{header}.{payload}.{b64url(signature)}" + + +def parse_token(token: str) -> Optional[dict]: + try: + header_b64, payload_b64, sig_b64 = token.split(".") + sign_input = f"{header_b64}.{payload_b64}".encode() + expected = b64url(hmac.new(JWT_SECRET.encode(), sign_input, hashlib.sha256).digest()) + if not hmac.compare_digest(expected, sig_b64): + return None + pad = '=' * (-len(payload_b64) % 4) + payload = json.loads(base64.urlsafe_b64decode(payload_b64 + pad)) + if int(payload.get("exp", 0)) < int(datetime.utcnow().timestamp()): + return None + return payload + except Exception: + return None + + +def get_user_from_auth(authorization: Optional[str]) -> sqlite3.Row: + if not authorization or not authorization.startswith("Bearer "): + raise HTTPException(status_code=401, detail="missing token") + token = authorization.split(" ", 1)[1].strip() + payload = parse_token(token) + if not payload: + raise HTTPException(status_code=401, detail="invalid token") + conn = get_conn() + user = conn.execute("SELECT * FROM users WHERE id = ?", (payload["sub"],)).fetchone() + conn.close() + if not user: + raise HTTPException(status_code=401, detail="user not found") + return user + + +class RegisterReq(BaseModel): + email: EmailStr + password: str + + +class LoginReq(BaseModel): + email: EmailStr + password: str + + +class BindDiscordReq(BaseModel): + discord_id: str + + +@router.post("/auth/register") +def register(body: RegisterReq): + ensure_tables() + conn = get_conn() + try: + pwd_hash = hash_password(body.password) + cur = conn.execute( + "INSERT INTO users (email, password_hash, created_at) VALUES (?, ?, ?)", + (body.email.lower(), pwd_hash, datetime.utcnow().isoformat()), + ) + user_id = cur.lastrowid + conn.execute( + "INSERT OR REPLACE INTO subscriptions (user_id, tier, expires_at) VALUES (?, 'free', NULL)", + (user_id,), + ) + conn.commit() + except sqlite3.IntegrityError: + conn.close() + raise HTTPException(status_code=400, detail="email already registered") + + user = conn.execute("SELECT id, email, discord_id, created_at FROM users WHERE id = ?", (user_id,)).fetchone() + conn.close() + return dict(user) + + +@router.post("/auth/login") +def login(body: LoginReq): + ensure_tables() + conn = get_conn() + user = conn.execute("SELECT * FROM users WHERE email = ?", (body.email.lower(),)).fetchone() + conn.close() + if not user or not verify_password(body.password, user["password_hash"]): + raise HTTPException(status_code=401, detail="invalid credentials") + token = create_token(user["id"], user["email"]) + return {"token": token, "token_type": "bearer"} + + +@router.post("/user/bind-discord") +def bind_discord(body: BindDiscordReq, authorization: Optional[str] = Header(default=None)): + user = get_user_from_auth(authorization) + conn = get_conn() + conn.execute("UPDATE users SET discord_id = ? WHERE id = ?", (body.discord_id, user["id"])) + conn.commit() + updated = conn.execute("SELECT id, email, discord_id, created_at FROM users WHERE id = ?", (user["id"],)).fetchone() + conn.close() + return dict(updated) + + +@router.get("/user/me") +def me(authorization: Optional[str] = Header(default=None)): + user = get_user_from_auth(authorization) + conn = get_conn() + sub = conn.execute("SELECT tier, expires_at FROM subscriptions WHERE user_id = ?", (user["id"],)).fetchone() + conn.close() + out = {"id": user["id"], "email": user["email"], "discord_id": user["discord_id"], "created_at": user["created_at"]} + out["subscription"] = dict(sub) if sub else {"tier": "free", "expires_at": None} + return out diff --git a/backend/subscriptions.py b/backend/subscriptions.py new file mode 100644 index 0000000..5916719 --- /dev/null +++ b/backend/subscriptions.py @@ -0,0 +1,23 @@ +import os +import sqlite3 +from fastapi import APIRouter + +DB_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "arb.db") + +router = APIRouter(prefix="/api", tags=["signals"]) + + +def get_conn(): + conn = sqlite3.connect(DB_PATH) + conn.row_factory = sqlite3.Row + return conn + + +@router.get("/signals/history") +def signals_history(): + conn = get_conn() + rows = conn.execute( + "SELECT id, symbol, rate, annualized, sent_at, message FROM signal_logs ORDER BY id DESC LIMIT 50" + ).fetchall() + conn.close() + return {"items": [dict(r) for r in rows]}