arbitrage-engine/backend/auth.py
fanziqi ad60a53262 review: add code audit annotations and REVIEW.md for v5.1
P0 issues annotated (critical, must fix before live trading):
- signal_engine.py: cooldown blocks reverse-signal position close
- paper_monitor.py + signal_engine.py: pnl_r 2x inflated for TP scenarios
- signal_engine.py: entry price uses 30min VWAP instead of real-time price
- paper_monitor.py + signal_engine.py: concurrent write race on paper_trades

P1 issues annotated (long-term stability):
- db.py: ensure_partitions uses timedelta(30d) causing missed monthly partitions
- signal_engine.py: float precision drift in buy_vol/sell_vol accumulation
- market_data_collector.py: single bare connection with no reconnect logic
- db.py: get_sync_pool initialization not thread-safe
- signal_engine.py: recent_large_trades deque has no maxlen

P2/P3 issues annotated across backend and frontend:
- coinbase_premium KeyError for XRP/SOL symbols
- liquidation_collector: redundant elif condition in aggregation logic
- auth.py: JWT secret hardcoded default, login rate-limit absent
- Frontend: concurrent refresh token race, AuthContext not synced on failure
- Frontend: universal catch{} swallows all API errors silently
- Frontend: serial API requests in LatestSignals, market-indicators over-polling

docs/REVIEW.md: comprehensive audit report with all 34 issues (P0×4, P1×5,
P2×6, P3×4 backend + FE-P1×4, FE-P2×8, FE-P3×3 frontend), fix suggestions
and prioritized remediation roadmap.
2026-03-01 17:14:52 +08:00

397 lines
15 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import base64
import hashlib
import hmac
import json
import logging
import os
import secrets
from datetime import datetime, timedelta, timezone
from typing import Optional

from fastapi import APIRouter, HTTPException, Header, Depends, Request
from pydantic import BaseModel, EmailStr

from db import get_sync_conn
# JWT signing key. Every real deployment must provide it through the environment,
# e.g. `export JWT_SECRET=<random 256-bit key>`.
# A hard-coded fallback key would let anyone who can read this file forge valid
# access tokens for any user, and an empty value would mean an empty HMAC key.
# So when the variable is unset or empty, a random per-process key is generated
# instead. Tokens signed with that key do not survive a restart and are not shared
# between worker processes, hence the warning.
JWT_SECRET = os.getenv("JWT_SECRET", "")
if not JWT_SECRET:
    JWT_SECRET = secrets.token_urlsafe(32)
    logging.getLogger(__name__).warning(
        "JWT_SECRET is not set; using a random per-process signing key. "
        "Access tokens will not survive restarts or be shared across workers."
    )
# Lifetime of the signed access token ("exp" claim) and of the DB-stored refresh tokens.
ACCESS_TOKEN_HOURS = 24
REFRESH_TOKEN_DAYS = 7
router = APIRouter(prefix="/api", tags=["auth"])
# ─── PG Schema ───────────────────────────────────────────────
# DDL for the auth tables. ensure_tables() splits this string on ";" and executes
# each statement separately, so no statement may contain a literal semicolon.
# Timestamp columns are TEXT. Application code writes naive-UTC ISO-8601 strings into
# them and compares those strings lexically. The column defaults use NOW()::TEXT,
# which renders in PostgreSQL's own timestamptz text format rather than Python's
# isoformat(), so defaulted values are formatted differently.
# NOTE(review): invite_codes.created_by is INTEGER while users.id is BIGINT; consider BIGINT.
AUTH_SCHEMA = """
CREATE TABLE IF NOT EXISTS users (
id BIGSERIAL PRIMARY KEY,
email TEXT UNIQUE NOT NULL,
password_hash TEXT NOT NULL,
discord_id TEXT,
role TEXT NOT NULL DEFAULT 'user',
banned INTEGER NOT NULL DEFAULT 0,
created_at TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS subscriptions (
user_id BIGINT PRIMARY KEY REFERENCES users(id),
tier TEXT NOT NULL DEFAULT 'free',
expires_at TEXT
);
CREATE TABLE IF NOT EXISTS invite_codes (
id BIGSERIAL PRIMARY KEY,
code TEXT UNIQUE NOT NULL,
created_by INTEGER,
max_uses INTEGER NOT NULL DEFAULT 1,
used_count INTEGER NOT NULL DEFAULT 0,
status TEXT NOT NULL DEFAULT 'active',
expires_at TEXT,
created_at TEXT DEFAULT (NOW()::TEXT)
);
CREATE TABLE IF NOT EXISTS invite_usage (
id BIGSERIAL PRIMARY KEY,
code_id BIGINT NOT NULL REFERENCES invite_codes(id),
user_id BIGINT NOT NULL REFERENCES users(id),
used_at TEXT DEFAULT (NOW()::TEXT)
);
CREATE TABLE IF NOT EXISTS refresh_tokens (
id BIGSERIAL PRIMARY KEY,
user_id BIGINT NOT NULL REFERENCES users(id),
token TEXT UNIQUE NOT NULL,
expires_at TEXT NOT NULL,
revoked INTEGER NOT NULL DEFAULT 0,
created_at TEXT DEFAULT (NOW()::TEXT)
);
"""
def ensure_tables():
    """Create the auth tables from AUTH_SCHEMA if they do not exist yet (best effort).

    Each statement is executed and committed on its own. If one statement fails
    (for example when another worker creates the same table concurrently), only that
    statement is rolled back; the failure is logged and the remaining statements still
    run. Called at the start of register and login.
    """
    log = logging.getLogger(__name__)
    with get_sync_conn() as conn:
        with conn.cursor() as cur:
            for stmt in AUTH_SCHEMA.split(";"):
                stmt = stmt.strip()
                if not stmt:
                    continue
                try:
                    cur.execute(stmt)
                    # Commit per statement: with one shared transaction, the rollback
                    # below would also discard CREATEs that had already succeeded.
                    conn.commit()
                except Exception:
                    conn.rollback()
                    log.warning(
                        "auth schema statement failed: %s",
                        stmt.splitlines()[0],
                        exc_info=True,
                    )
# ─── DB helper ───────────────────────────────────────────────
def _fetchone(sql, params=None):
    """Run a query and return its first row as a ``{column: value}`` dict, or None."""
    with get_sync_conn() as conn, conn.cursor() as cur:
        cur.execute(sql, params or ())
        record = cur.fetchone()
        if not record:
            return None
        return dict(zip((col[0] for col in cur.description), record))
def _fetchall(sql, params=None):
    """Run a query and return every row as a ``{column: value}`` dict."""
    with get_sync_conn() as conn, conn.cursor() as cur:
        cur.execute(sql, params or ())
        names = [col[0] for col in cur.description]
        return [dict(zip(names, record)) for record in cur.fetchall()]
def _execute(sql, params=None):
    """Execute ``sql``, commit, then return the statement's first result row.

    DB-API cursors raise from ``fetchone`` when the statement produced no result set
    (e.g. INSERT/UPDATE without RETURNING). That error is swallowed and None is
    returned instead.
    """
    with get_sync_conn() as conn, conn.cursor() as cur:
        cur.execute(sql, params or ())
        conn.commit()
        try:
            first_row = cur.fetchone()
        except Exception:
            first_row = None
        return first_row
# ─── Password utils ──────────────────────────────────────────
def hash_password(password: str) -> str:
    """Hash ``password`` with scrypt (N=2**14, r=8, p=1) and a fresh random salt.

    Returns ``"<salt_hex>$<digest_hex>"``. The salt is 16 random bytes as 32 hex
    characters, and it is the hex *text* (UTF-8 encoded) that is fed to scrypt.
    """
    salt_hex = secrets.token_hex(16)
    derived = hashlib.scrypt(
        password.encode(),
        salt=salt_hex.encode(),
        n=2 ** 14,
        r=8,
        p=1,
    )
    return "$".join((salt_hex, derived.hex()))
def verify_password(password: str, stored: str) -> bool:
    """Check ``password`` against a ``"<salt_hex>$<digest_hex>"`` string.

    The digest is recomputed with the same scrypt parameters as hash_password and
    compared in constant time. Any malformed ``stored`` value yields False.
    """
    try:
        salt_hex, expected_hex = stored.split("$", 1)
        actual_hex = hashlib.scrypt(
            password.encode(), salt=salt_hex.encode(), n=2 ** 14, r=8, p=1
        ).hex()
    except Exception:
        return False
    try:
        return hmac.compare_digest(actual_hex, expected_hex)
    except Exception:
        return False
# ─── JWT utils ───────────────────────────────────────────────
def b64url(data: bytes) -> str:
    """Encode ``data`` as URL-safe base64 text with the ``=`` padding removed."""
    encoded = base64.urlsafe_b64encode(data).decode()
    return encoded.rstrip("=")
def create_access_token(user_id: int, email: str, role: str) -> str:
    """Build an HS256-signed JWT access token for a user.

    The payload carries ``sub`` (user id), ``email``, ``role``, ``type="access"``
    and ``exp``, which is the expiry as POSIX seconds, ACCESS_TOKEN_HOURS from now.
    """
    header = b64url(json.dumps({"alg": "HS256", "typ": "JWT"}, separators=(",", ":")).encode())
    # An aware UTC datetime is required: .timestamp() interprets a naive
    # datetime.utcnow() value as *local* time, which skewed "exp" by the
    # server's UTC offset.
    expires = datetime.now(timezone.utc) + timedelta(hours=ACCESS_TOKEN_HOURS)
    payload = b64url(json.dumps({
        "sub": user_id, "email": email, "role": role,
        "exp": int(expires.timestamp()), "type": "access",
    }, separators=(",", ":")).encode())
    sign_input = f"{header}.{payload}".encode()
    signature = hmac.new(JWT_SECRET.encode(), sign_input, hashlib.sha256).digest()
    return f"{header}.{payload}.{b64url(signature)}"


def create_refresh_token(user_id: int) -> str:
    """Create, store and return a random opaque refresh token for ``user_id``.

    The row's ``expires_at`` is REFRESH_TOKEN_DAYS from now. It is stored as a
    naive-UTC ISO-8601 string because the refresh endpoint compares it lexically
    against a naive-UTC ISO-8601 "now".
    """
    token = secrets.token_urlsafe(48)
    now_utc = datetime.now(timezone.utc).replace(tzinfo=None)
    expires_at = (now_utc + timedelta(days=REFRESH_TOKEN_DAYS)).isoformat()
    _execute(
        "INSERT INTO refresh_tokens (user_id, token, expires_at) VALUES (%s, %s, %s)",
        (user_id, token, expires_at),
    )
    return token


def parse_token(token: str) -> Optional[dict]:
    """Verify an HS256 token from create_access_token and return its payload dict.

    Returns None if the token is malformed, its signature does not match JWT_SECRET,
    or its ``exp`` lies in the past. The token ``type`` is not checked here.
    """
    try:
        header_b64, payload_b64, sig_b64 = token.split(".")
        sign_input = f"{header_b64}.{payload_b64}".encode()
        expected = b64url(hmac.new(JWT_SECRET.encode(), sign_input, hashlib.sha256).digest())
        if not hmac.compare_digest(expected, sig_b64):
            return None
        pad = "=" * (-len(payload_b64) % 4)
        payload = json.loads(base64.urlsafe_b64decode(payload_b64 + pad))
        # Compare against real UTC epoch seconds, matching create_access_token.
        if int(payload.get("exp", 0)) < int(datetime.now(timezone.utc).timestamp()):
            return None
        return payload
    except Exception:
        # Wrong segment count, bad base64/JSON, non-ASCII signature text, non-object
        # payload or non-numeric "exp": all treated as an invalid token.
        return None
# ─── Auth dependency ─────────────────────────────────────────
def get_current_user(authorization: Optional[str] = Header(default=None)) -> dict:
    """FastAPI dependency that resolves ``Authorization: Bearer <token>`` to a user row.

    Raises:
        HTTPException 401: header missing or not a Bearer credential; token invalid,
            expired, not an access token or without a subject; user not found.
        HTTPException 403: the user is banned.
    """
    # The auth-scheme name is case-insensitive (RFC 7235), so "bearer" is accepted too.
    scheme, sep, token = (authorization or "").partition(" ")
    if not sep or scheme.lower() != "bearer":
        raise HTTPException(status_code=401, detail="missing token")
    payload = parse_token(token.strip())
    if not payload:
        raise HTTPException(status_code=401, detail="invalid or expired token")
    if payload.get("type") != "access":
        raise HTTPException(status_code=401, detail="invalid token type")
    user_id = payload.get("sub")
    if user_id is None:
        # Signed but without a subject: reject as 401 rather than crash with KeyError.
        raise HTTPException(status_code=401, detail="invalid or expired token")
    user = _fetchone("SELECT * FROM users WHERE id = %s", (user_id,))
    if not user:
        raise HTTPException(status_code=401, detail="user not found")
    if user["banned"]:
        raise HTTPException(status_code=403, detail="account banned")
    return user
def require_admin(user: dict = Depends(get_current_user)) -> dict:
    """FastAPI dependency: pass the current user through only if its role is ``admin``.

    Raises HTTPException 403 for any other role.
    """
    is_admin = user.get("role") == "admin"
    if is_admin:
        return user
    raise HTTPException(status_code=403, detail="admin required")
# ─── Request models ──────────────────────────────────────────
# Request body for POST /auth/register.
class RegisterReq(BaseModel):
    email: EmailStr
    password: str
    # Matched case-sensitively (SQL "=") against invite_codes.code; generated codes are upper-case.
    invite_code: str


# Request body for POST /auth/login; the email is lower-cased before lookup.
class LoginReq(BaseModel):
    email: EmailStr
    password: str


# Request body for POST /auth/refresh.
class RefreshReq(BaseModel):
    # Opaque token previously returned by register/login/refresh.
    refresh_token: str


# Request body for POST /admin/invite-codes.
class GenInviteReq(BaseModel):
    # Number of codes to generate. Not range-checked; a value <= 0 creates no codes.
    count: int = 1
    # Registrations allowed per code before it is marked exhausted.
    max_uses: int = 1


# Request body for PUT /admin/users/{user_id}/ban; True bans, False unbans (stored as 1/0).
class BanUserReq(BaseModel):
    banned: bool = True
# ─── Auth routes ─────────────────────────────────────────────
def _claim_invite_and_create_user(conn, email: str, pwd_hash: str, invite_id) -> Optional[int]:
    """Consume one use of an invite and create the user in a single transaction on ``conn``.

    Returns the new user id after committing, or None (after rolling back) when the
    invite can no longer be claimed. Any database error rolls back and is re-raised.
    """
    try:
        with conn.cursor() as cur:
            # Claim atomically. The conditional UPDATE row-locks the invite, so concurrent
            # registrations with the same code are serialized and cannot exceed max_uses.
            cur.execute(
                "UPDATE invite_codes SET used_count = used_count + 1 "
                "WHERE id = %s AND status = 'active' AND used_count < max_uses "
                "RETURNING used_count, max_uses",
                (invite_id,),
            )
            claimed = cur.fetchone()
            if not claimed:
                conn.rollback()
                return None
            new_count, max_uses = claimed
            cur.execute(
                "INSERT INTO users (email, password_hash, role, banned, created_at) VALUES (%s, %s, 'user', 0, %s) RETURNING id",
                (email, pwd_hash, datetime.utcnow().isoformat()),
            )
            user_id = cur.fetchone()[0]
            cur.execute(
                "INSERT INTO subscriptions (user_id, tier, expires_at) VALUES (%s, 'free', NULL) ON CONFLICT(user_id) DO NOTHING",
                (user_id,),
            )
            cur.execute(
                "INSERT INTO invite_usage (code_id, user_id) VALUES (%s, %s)",
                (invite_id, user_id),
            )
            # Use the post-increment count returned by the UPDATE, not a stale pre-read.
            if new_count >= max_uses:
                cur.execute(
                    "UPDATE invite_codes SET status = 'exhausted' WHERE id = %s",
                    (invite_id,),
                )
        conn.commit()
        return user_id
    except Exception:
        conn.rollback()
        raise


@router.post("/auth/register")
def register(body: RegisterReq):
    """Register a new account with an invite code and return an access/refresh token pair.

    Raises:
        HTTPException 400: invite code invalid, disabled, exhausted or expired, or the
            email is already registered.
        HTTPException 500: any other database error (logged server-side, not echoed).
    """
    ensure_tables()
    invite = _fetchone("SELECT * FROM invite_codes WHERE code = %s", (body.invite_code,))
    # These checks produce precise error messages. The authoritative usage check is the
    # atomic claim inside _claim_invite_and_create_user().
    if not invite:
        raise HTTPException(status_code=400, detail="invalid invite code")
    if invite["status"] != "active":
        raise HTTPException(status_code=400, detail="invite code disabled")
    if invite["used_count"] >= invite["max_uses"]:
        raise HTTPException(status_code=400, detail="invite code exhausted")
    # Assumes expires_at is stored as a naive-UTC ISO-8601 string (lexical comparison).
    if invite["expires_at"] and invite["expires_at"] < datetime.utcnow().isoformat():
        raise HTTPException(status_code=400, detail="invite code expired")
    pwd_hash = hash_password(body.password)
    try:
        with get_sync_conn() as conn:
            user_id = _claim_invite_and_create_user(
                conn, body.email.lower(), pwd_hash, invite["id"]
            )
    except Exception as e:
        # SQLSTATE 23505 is unique_violation. psycopg2 exposes it as .pgcode and
        # psycopg 3 as .sqlstate; the message check is kept as a fallback.
        sqlstate = getattr(e, "pgcode", None) or getattr(e, "sqlstate", None)
        message = str(e).lower()
        if sqlstate == "23505" or "unique" in message or "duplicate" in message:
            raise HTTPException(status_code=400, detail="email already registered") from e
        # Never echo raw database errors to the client; keep them in the server log.
        logging.getLogger(__name__).exception("registration failed")
        raise HTTPException(status_code=500, detail="registration failed") from e
    if user_id is None:
        # After the checks above, a concurrent registration used up the code, or the
        # code was disabled.
        raise HTTPException(status_code=400, detail="invite code exhausted")
    user = _fetchone("SELECT * FROM users WHERE id = %s", (user_id,))
    access = create_access_token(user["id"], user["email"], user["role"])
    refresh = create_refresh_token(user["id"])
    return {
        "access_token": access,
        "refresh_token": refresh,
        "token_type": "bearer",
        "expires_in": ACCESS_TOKEN_HOURS * 3600,
        "user": {"id": user["id"], "email": user["email"], "role": user["role"]},
    }
# Stand-in stored hash for unknown emails. Verifying against it still costs one scrypt
# derivation, so a failed login takes about as long whether or not the account exists,
# which limits email enumeration via response timing. Its all-zero digest is not a
# value scrypt will realistically produce, so it never verifies.
_DUMMY_PASSWORD_HASH = "0" * 32 + "$" + "0" * 128


# TODO: this endpoint has no rate limiting and can be brute-forced. Suggested fix:
# use slowapi or a Redis counter, returning 429 after more than 10 attempts from one
# IP within 60 seconds.
@router.post("/auth/login")
def login(body: LoginReq):
    """Authenticate by email/password and return an access/refresh token pair.

    Raises HTTPException 401 for an unknown email or a wrong password (same message
    for both) and 403 for a banned account.
    """
    ensure_tables()
    user = _fetchone("SELECT * FROM users WHERE email = %s", (body.email.lower(),))
    # Always run the password check, even without a user, to keep timing uniform.
    stored_hash = user["password_hash"] if user else _DUMMY_PASSWORD_HASH
    password_ok = verify_password(body.password, stored_hash)
    if not user or not password_ok:
        raise HTTPException(status_code=401, detail="invalid credentials")
    if user["banned"]:
        raise HTTPException(status_code=403, detail="account banned")
    access = create_access_token(user["id"], user["email"], user["role"])
    refresh = create_refresh_token(user["id"])
    return {
        "access_token": access,
        "refresh_token": refresh,
        "token_type": "bearer",
        "expires_in": ACCESS_TOKEN_HOURS * 3600,
        "user": {"id": user["id"], "email": user["email"], "role": user["role"]},
    }
@router.post("/auth/refresh")
def refresh_token(body: RefreshReq):
    """Rotate a refresh token: revoke the presented one and issue a new token pair.

    Raises HTTPException 401 if the token is unknown, already revoked (including by a
    concurrent request) or expired, and 403 if its user is missing or banned.
    """
    row = _fetchone(
        "SELECT * FROM refresh_tokens WHERE token = %s AND revoked = 0", (body.refresh_token,)
    )
    if not row:
        raise HTTPException(status_code=401, detail="invalid refresh token")
    # expires_at is written as a naive-UTC ISO-8601 string, so compare in the same format.
    if row["expires_at"] < datetime.utcnow().isoformat():
        raise HTTPException(status_code=401, detail="refresh token expired")
    user = _fetchone("SELECT * FROM users WHERE id = %s", (row["user_id"],))
    if not user or user["banned"]:
        raise HTTPException(status_code=403, detail="account unavailable")
    # Revoke atomically. The UPDATE only matches while the token is still unrevoked, so
    # if two requests race with the same token, exactly one sees rowcount 1. The other
    # is rejected instead of also receiving a fresh token pair.
    with get_sync_conn() as conn:
        with conn.cursor() as cur:
            cur.execute(
                "UPDATE refresh_tokens SET revoked = 1 WHERE id = %s AND revoked = 0",
                (row["id"],),
            )
            revoked_now = cur.rowcount == 1
        conn.commit()
    if not revoked_now:
        raise HTTPException(status_code=401, detail="invalid refresh token")
    access = create_access_token(user["id"], user["email"], user["role"])
    new_refresh = create_refresh_token(user["id"])
    return {
        "access_token": access,
        "refresh_token": new_refresh,
        "token_type": "bearer",
        "expires_in": ACCESS_TOKEN_HOURS * 3600,
    }
@router.get("/auth/me")
def me(user: dict = Depends(get_current_user)):
    """Return the caller's profile plus subscription (free tier if no row exists)."""
    sub_row = _fetchone(
        "SELECT tier, expires_at FROM subscriptions WHERE user_id = %s", (user["id"],)
    )
    if sub_row:
        subscription = dict(sub_row)
    else:
        subscription = {"tier": "free", "expires_at": None}
    profile = {
        "id": user["id"],
        "email": user["email"],
        "role": user["role"],
        "discord_id": user.get("discord_id"),
        "created_at": user["created_at"],
    }
    profile["subscription"] = subscription
    return profile
# ─── Admin routes ────────────────────────────────────────────
@router.post("/admin/invite-codes")
def gen_invite_codes(body: GenInviteReq, admin: dict = Depends(require_admin)):
    """Generate ``body.count`` invite codes, each with ``body.max_uses`` uses (admin only).

    All inserts run on one connection and are committed together at the end.
    """
    created = []
    creator_id = admin["id"]
    with get_sync_conn() as conn, conn.cursor() as cur:
        for _ in range(body.count):
            # 6 random bytes give 8 URL-safe base64 characters, then upper-cased.
            new_code = secrets.token_urlsafe(6)[:8].upper()
            cur.execute(
                "INSERT INTO invite_codes (code, created_by, max_uses) VALUES (%s, %s, %s)",
                (new_code, creator_id, body.max_uses),
            )
            created.append(new_code)
        conn.commit()
    return {"codes": created}
@router.get("/admin/invite-codes")
def list_invite_codes(admin: dict = Depends(require_admin)):
    """List every invite code, newest first (admin only)."""
    query = (
        "SELECT id, code, max_uses, used_count, status, expires_at, created_at "
        "FROM invite_codes ORDER BY id DESC"
    )
    return {"items": _fetchall(query)}
@router.delete("/admin/invite-codes/{code_id}")
def disable_invite_code(code_id: int, admin: dict = Depends(require_admin)):
    """Mark an invite code as disabled (admin only); reports ok even if the id does not exist."""
    params = (code_id,)
    _execute("UPDATE invite_codes SET status = 'disabled' WHERE id = %s", params)
    return {"ok": True}
@router.get("/admin/users")
def list_users(admin: dict = Depends(require_admin)):
    """List all users, newest first, without password hashes (admin only)."""
    query = (
        "SELECT id, email, role, banned, discord_id, created_at "
        "FROM users ORDER BY id DESC"
    )
    return {"items": _fetchall(query)}
@router.put("/admin/users/{user_id}/ban")
def ban_user(user_id: int, body: BanUserReq, admin: dict = Depends(require_admin)):
    """Set (banned=True) or clear (banned=False) a user's ban flag, stored as 1/0."""
    flag = int(body.banned)
    _execute("UPDATE users SET banned = %s WHERE id = %s", (flag, user_id))
    return {"ok": True, "user_id": user_id, "banned": body.banned}