feat: add auth and subscriptions backend modules
This commit is contained in:
parent
93043009ac
commit
cf531d8c44
196
backend/auth.py
Normal file
196
backend/auth.py
Normal file
@ -0,0 +1,196 @@
|
|||||||
|
import os
|
||||||
|
import sqlite3
|
||||||
|
import hashlib
|
||||||
|
import secrets
|
||||||
|
import hmac
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, Header
|
||||||
|
from pydantic import BaseModel, EmailStr
|
||||||
|
|
||||||
|
DB_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "arb.db")
|
||||||
|
JWT_SECRET = os.getenv("JWT_SECRET", "arb-secret-change-me")
|
||||||
|
JWT_EXPIRE_HOURS = 24 * 7
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api", tags=["auth"])
|
||||||
|
|
||||||
|
|
||||||
|
def get_conn():
|
||||||
|
conn = sqlite3.connect(DB_PATH)
|
||||||
|
conn.row_factory = sqlite3.Row
|
||||||
|
return conn
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_tables():
|
||||||
|
conn = get_conn()
|
||||||
|
cur = conn.cursor()
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS users (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
email TEXT UNIQUE NOT NULL,
|
||||||
|
password_hash TEXT NOT NULL,
|
||||||
|
discord_id TEXT,
|
||||||
|
created_at TEXT NOT NULL
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS subscriptions (
|
||||||
|
user_id INTEGER PRIMARY KEY,
|
||||||
|
tier TEXT NOT NULL DEFAULT 'free',
|
||||||
|
expires_at TEXT,
|
||||||
|
FOREIGN KEY(user_id) REFERENCES users(id)
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS signal_logs (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
symbol TEXT NOT NULL,
|
||||||
|
rate REAL NOT NULL,
|
||||||
|
annualized REAL NOT NULL,
|
||||||
|
sent_at TEXT NOT NULL,
|
||||||
|
message TEXT NOT NULL
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
def hash_password(password: str) -> str:
|
||||||
|
salt = secrets.token_hex(16)
|
||||||
|
digest = hashlib.scrypt(password.encode(), salt=salt.encode(), n=2**14, r=8, p=1).hex()
|
||||||
|
return f"{salt}${digest}"
|
||||||
|
|
||||||
|
|
||||||
|
def verify_password(password: str, stored: str) -> bool:
|
||||||
|
try:
|
||||||
|
salt, digest = stored.split("$", 1)
|
||||||
|
candidate = hashlib.scrypt(password.encode(), salt=salt.encode(), n=2**14, r=8, p=1).hex()
|
||||||
|
return hmac.compare_digest(candidate, digest)
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def b64url(data: bytes) -> str:
|
||||||
|
return base64.urlsafe_b64encode(data).rstrip(b"=").decode()
|
||||||
|
|
||||||
|
|
||||||
|
def create_token(user_id: int, email: str) -> str:
|
||||||
|
header = b64url(json.dumps({"alg": "HS256", "typ": "JWT"}, separators=(",", ":")).encode())
|
||||||
|
exp = int((datetime.utcnow() + timedelta(hours=JWT_EXPIRE_HOURS)).timestamp())
|
||||||
|
payload = b64url(json.dumps({"sub": user_id, "email": email, "exp": exp}, separators=(",", ":")).encode())
|
||||||
|
sign_input = f"{header}.{payload}".encode()
|
||||||
|
signature = hmac.new(JWT_SECRET.encode(), sign_input, hashlib.sha256).digest()
|
||||||
|
return f"{header}.{payload}.{b64url(signature)}"
|
||||||
|
|
||||||
|
|
||||||
|
def parse_token(token: str) -> Optional[dict]:
|
||||||
|
try:
|
||||||
|
header_b64, payload_b64, sig_b64 = token.split(".")
|
||||||
|
sign_input = f"{header_b64}.{payload_b64}".encode()
|
||||||
|
expected = b64url(hmac.new(JWT_SECRET.encode(), sign_input, hashlib.sha256).digest())
|
||||||
|
if not hmac.compare_digest(expected, sig_b64):
|
||||||
|
return None
|
||||||
|
pad = '=' * (-len(payload_b64) % 4)
|
||||||
|
payload = json.loads(base64.urlsafe_b64decode(payload_b64 + pad))
|
||||||
|
if int(payload.get("exp", 0)) < int(datetime.utcnow().timestamp()):
|
||||||
|
return None
|
||||||
|
return payload
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_user_from_auth(authorization: Optional[str]) -> sqlite3.Row:
|
||||||
|
if not authorization or not authorization.startswith("Bearer "):
|
||||||
|
raise HTTPException(status_code=401, detail="missing token")
|
||||||
|
token = authorization.split(" ", 1)[1].strip()
|
||||||
|
payload = parse_token(token)
|
||||||
|
if not payload:
|
||||||
|
raise HTTPException(status_code=401, detail="invalid token")
|
||||||
|
conn = get_conn()
|
||||||
|
user = conn.execute("SELECT * FROM users WHERE id = ?", (payload["sub"],)).fetchone()
|
||||||
|
conn.close()
|
||||||
|
if not user:
|
||||||
|
raise HTTPException(status_code=401, detail="user not found")
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
class RegisterReq(BaseModel):
|
||||||
|
email: EmailStr
|
||||||
|
password: str
|
||||||
|
|
||||||
|
|
||||||
|
class LoginReq(BaseModel):
|
||||||
|
email: EmailStr
|
||||||
|
password: str
|
||||||
|
|
||||||
|
|
||||||
|
class BindDiscordReq(BaseModel):
|
||||||
|
discord_id: str
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/auth/register")
|
||||||
|
def register(body: RegisterReq):
|
||||||
|
ensure_tables()
|
||||||
|
conn = get_conn()
|
||||||
|
try:
|
||||||
|
pwd_hash = hash_password(body.password)
|
||||||
|
cur = conn.execute(
|
||||||
|
"INSERT INTO users (email, password_hash, created_at) VALUES (?, ?, ?)",
|
||||||
|
(body.email.lower(), pwd_hash, datetime.utcnow().isoformat()),
|
||||||
|
)
|
||||||
|
user_id = cur.lastrowid
|
||||||
|
conn.execute(
|
||||||
|
"INSERT OR REPLACE INTO subscriptions (user_id, tier, expires_at) VALUES (?, 'free', NULL)",
|
||||||
|
(user_id,),
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
except sqlite3.IntegrityError:
|
||||||
|
conn.close()
|
||||||
|
raise HTTPException(status_code=400, detail="email already registered")
|
||||||
|
|
||||||
|
user = conn.execute("SELECT id, email, discord_id, created_at FROM users WHERE id = ?", (user_id,)).fetchone()
|
||||||
|
conn.close()
|
||||||
|
return dict(user)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/auth/login")
|
||||||
|
def login(body: LoginReq):
|
||||||
|
ensure_tables()
|
||||||
|
conn = get_conn()
|
||||||
|
user = conn.execute("SELECT * FROM users WHERE email = ?", (body.email.lower(),)).fetchone()
|
||||||
|
conn.close()
|
||||||
|
if not user or not verify_password(body.password, user["password_hash"]):
|
||||||
|
raise HTTPException(status_code=401, detail="invalid credentials")
|
||||||
|
token = create_token(user["id"], user["email"])
|
||||||
|
return {"token": token, "token_type": "bearer"}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/user/bind-discord")
|
||||||
|
def bind_discord(body: BindDiscordReq, authorization: Optional[str] = Header(default=None)):
|
||||||
|
user = get_user_from_auth(authorization)
|
||||||
|
conn = get_conn()
|
||||||
|
conn.execute("UPDATE users SET discord_id = ? WHERE id = ?", (body.discord_id, user["id"]))
|
||||||
|
conn.commit()
|
||||||
|
updated = conn.execute("SELECT id, email, discord_id, created_at FROM users WHERE id = ?", (user["id"],)).fetchone()
|
||||||
|
conn.close()
|
||||||
|
return dict(updated)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/user/me")
|
||||||
|
def me(authorization: Optional[str] = Header(default=None)):
|
||||||
|
user = get_user_from_auth(authorization)
|
||||||
|
conn = get_conn()
|
||||||
|
sub = conn.execute("SELECT tier, expires_at FROM subscriptions WHERE user_id = ?", (user["id"],)).fetchone()
|
||||||
|
conn.close()
|
||||||
|
out = {"id": user["id"], "email": user["email"], "discord_id": user["discord_id"], "created_at": user["created_at"]}
|
||||||
|
out["subscription"] = dict(sub) if sub else {"tier": "free", "expires_at": None}
|
||||||
|
return out
|
||||||
23
backend/subscriptions.py
Normal file
23
backend/subscriptions.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
import os
|
||||||
|
import sqlite3
|
||||||
|
from fastapi import APIRouter
|
||||||
|
|
||||||
|
DB_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "arb.db")
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api", tags=["signals"])
|
||||||
|
|
||||||
|
|
||||||
|
def get_conn():
|
||||||
|
conn = sqlite3.connect(DB_PATH)
|
||||||
|
conn.row_factory = sqlite3.Row
|
||||||
|
return conn
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/signals/history")
|
||||||
|
def signals_history():
|
||||||
|
conn = get_conn()
|
||||||
|
rows = conn.execute(
|
||||||
|
"SELECT id, symbol, rate, annualized, sent_at, message FROM signal_logs ORDER BY id DESC LIMIT 50"
|
||||||
|
).fetchall()
|
||||||
|
conn.close()
|
||||||
|
return {"items": [dict(r) for r in rows]}
|
||||||
Loading…
Reference in New Issue
Block a user