P0 issues annotated (critical, must fix before live trading):
- signal_engine.py: cooldown blocks reverse-signal position close
- paper_monitor.py + signal_engine.py: pnl_r 2x inflated for TP scenarios
- signal_engine.py: entry price uses 30min VWAP instead of real-time price
- paper_monitor.py + signal_engine.py: concurrent write race on paper_trades
P1 issues annotated (long-term stability):
- db.py: ensure_partitions uses timedelta(30d) causing missed monthly partitions
- signal_engine.py: float precision drift in buy_vol/sell_vol accumulation
- market_data_collector.py: single bare connection with no reconnect logic
- db.py: get_sync_pool initialization not thread-safe
- signal_engine.py: recent_large_trades deque has no maxlen
P2/P3 issues annotated across backend and frontend:
- coinbase_premium KeyError for XRP/SOL symbols
- liquidation_collector: redundant elif condition in aggregation logic
- auth.py: JWT secret hardcoded default, login rate-limit absent
- Frontend: concurrent refresh token race, AuthContext not synced on failure
- Frontend: universal catch{} swallows all API errors silently
- Frontend: serial API requests in LatestSignals, market-indicators over-polling
docs/REVIEW.md: comprehensive audit report covering all 34 issues (backend:
P0×4, P1×5, P2×6, P3×4; frontend: FE-P1×4, FE-P2×8, FE-P3×3), with fix
suggestions and a prioritized remediation roadmap.
397 lines · 15 KiB · Python
import base64
import hashlib
import hmac
import json
import logging
import os
import secrets
import threading
import time
from collections import deque
from datetime import datetime, timedelta, timezone
from typing import Optional

from fastapi import APIRouter, HTTPException, Header, Depends, Request
from pydantic import BaseModel, EmailStr

from db import get_sync_conn
|
||
|
||
# [REVIEW] P3 | JWT 密钥硬编码默认值,若未设置环境变量则使用明文密钥
|
||
# 任何能读到此文件的人均可伪造合法的 JWT token,获取所有用户权限
|
||
# 修复:移除默认值,改为 os.getenv("JWT_SECRET") 并在启动时校验非空
|
||
# 部署:server 上必须设置 export JWT_SECRET=<随机256位密钥>
|
||
JWT_SECRET = os.getenv("JWT_SECRET", "arb-engine-jwt-secret-v2-2026")
|
||
ACCESS_TOKEN_HOURS = 24
|
||
REFRESH_TOKEN_DAYS = 7
|
||
|
||
# Every route in this module is mounted under /api and tagged "auth" in the OpenAPI docs.
router = APIRouter(tags=["auth"], prefix="/api")
|
||
|
||
|
||
# ─── PG Schema ───────────────────────────────────────────────

# DDL for the auth tables. ensure_tables() executes it statement by statement
# after splitting on ";", so no statement may contain a literal semicolon.
# Timestamp columns (created_at, expires_at, used_at) are TEXT; the handlers
# in this module compare them against datetime.utcnow().isoformat() strings.
# NOTE(review): invite_codes.created_by is INTEGER with no foreign key while
# users.id is BIGINT — presumably meant to reference users(id); confirm before
# changing, since CREATE TABLE IF NOT EXISTS never alters an existing table.
AUTH_SCHEMA = """
CREATE TABLE IF NOT EXISTS users (
    id BIGSERIAL PRIMARY KEY,
    email TEXT UNIQUE NOT NULL,
    password_hash TEXT NOT NULL,
    discord_id TEXT,
    role TEXT NOT NULL DEFAULT 'user',
    banned INTEGER NOT NULL DEFAULT 0,
    created_at TEXT NOT NULL
);

CREATE TABLE IF NOT EXISTS subscriptions (
    user_id BIGINT PRIMARY KEY REFERENCES users(id),
    tier TEXT NOT NULL DEFAULT 'free',
    expires_at TEXT
);

CREATE TABLE IF NOT EXISTS invite_codes (
    id BIGSERIAL PRIMARY KEY,
    code TEXT UNIQUE NOT NULL,
    created_by INTEGER,
    max_uses INTEGER NOT NULL DEFAULT 1,
    used_count INTEGER NOT NULL DEFAULT 0,
    status TEXT NOT NULL DEFAULT 'active',
    expires_at TEXT,
    created_at TEXT DEFAULT (NOW()::TEXT)
);

CREATE TABLE IF NOT EXISTS invite_usage (
    id BIGSERIAL PRIMARY KEY,
    code_id BIGINT NOT NULL REFERENCES invite_codes(id),
    user_id BIGINT NOT NULL REFERENCES users(id),
    used_at TEXT DEFAULT (NOW()::TEXT)
);

CREATE TABLE IF NOT EXISTS refresh_tokens (
    id BIGSERIAL PRIMARY KEY,
    user_id BIGINT NOT NULL REFERENCES users(id),
    token TEXT UNIQUE NOT NULL,
    expires_at TEXT NOT NULL,
    revoked INTEGER NOT NULL DEFAULT 0,
    created_at TEXT DEFAULT (NOW()::TEXT)
);
"""
|
||
|
||
|
||
# Set after a pass in which every AUTH_SCHEMA statement succeeded, so request
# handlers that call ensure_tables() skip the DDL round trips afterwards.
_tables_ready = False


def ensure_tables():
    """Create the auth tables if they do not exist yet.

    Each statement of AUTH_SCHEMA is executed and committed on its own, so a
    failing statement (rolled back, logged and skipped — best effort, nothing
    is raised) cannot undo tables created by the statements before it.
    Once a pass completes without failures, later calls return immediately;
    after a partial failure the next call retries.
    """
    global _tables_ready
    if _tables_ready:
        return
    statements = [s.strip() for s in AUTH_SCHEMA.split(";") if s.strip()]
    all_ok = True
    with get_sync_conn() as conn:
        with conn.cursor() as cur:
            for stmt in statements:
                try:
                    cur.execute(stmt)
                    # Commit per statement: a later rollback must not discard
                    # tables that were already created successfully.
                    conn.commit()
                except Exception:
                    conn.rollback()
                    all_ok = False
                    logging.getLogger(__name__).warning(
                        "auth schema statement failed: %s", stmt.split("(", 1)[0], exc_info=True
                    )
    if all_ok:
        _tables_ready = True
|
||
|
||
|
||
# ─── DB helper ───────────────────────────────────────────────

def _fetchone(sql, params=None):
    """Run *sql* and return its first row as a ``{column_name: value}`` dict.

    Returns ``None`` when the query produces no row. No commit is issued.
    """
    with get_sync_conn() as conn, conn.cursor() as cur:
        cur.execute(sql, params or ())
        first = cur.fetchone()
        if not first:
            return None
        names = [column[0] for column in cur.description]
        return dict(zip(names, first))
|
||
|
||
|
||
def _fetchall(sql, params=None):
    """Run *sql* and return all rows as a list of ``{column_name: value}`` dicts."""
    with get_sync_conn() as conn, conn.cursor() as cur:
        cur.execute(sql, params or ())
        names = [column[0] for column in cur.description]
        records = cur.fetchall()
    return [dict(zip(names, record)) for record in records]
|
||
|
||
|
||
def _execute(sql, params=None):
    """Run a single write statement in its own transaction and commit it.

    Returns the first result row as produced by the driver's ``fetchone()``
    when the statement yields rows (e.g. ``... RETURNING``), otherwise ``None``.
    On failure the transaction is rolled back and the exception propagates.
    """
    with get_sync_conn() as conn:
        try:
            with conn.cursor() as cur:
                cur.execute(sql, params or ())
                # Per DB-API 2.0, description is None for statements that
                # return no rows; checking it replaces a catch-all around
                # fetchone() that could also hide genuine errors. The row is
                # read before commit so it does not depend on driver behaviour
                # after the transaction ends.
                row = cur.fetchone() if cur.description is not None else None
            conn.commit()
        except Exception:
            # Do not hand an aborted transaction back to the pool.
            conn.rollback()
            raise
        return row
|
||
|
||
|
||
# ─── Password utils ──────────────────────────────────────────

def hash_password(password: str) -> str:
    """Hash *password* with scrypt (N=2**14, r=8, p=1) under a fresh random salt.

    Returns ``"<salt>$<digest>"``: a 32-character hex salt followed by the hex
    scrypt digest. The salt's hex *text* (not its raw bytes) is what is fed
    to scrypt, so verification must encode it the same way.
    """
    salt_hex = secrets.token_hex(16)
    derived = hashlib.scrypt(
        password.encode(),
        salt=salt_hex.encode(),
        n=2 ** 14,
        r=8,
        p=1,
    )
    return salt_hex + "$" + derived.hex()
|
||
|
||
|
||
def verify_password(password: str, stored: str) -> bool:
    """Check *password* against a ``"<salt>$<digest>"`` string.

    Uses scrypt with N=2**14, r=8, p=1 and a constant-time comparison.
    Any malformed *stored* value or hashing error yields ``False`` rather
    than an exception.
    """
    try:
        salt_text, expected_hex = stored.split("$", 1)
        actual_hex = hashlib.scrypt(
            password.encode(), salt=salt_text.encode(), n=2 ** 14, r=8, p=1
        ).hex()
        return hmac.compare_digest(actual_hex, expected_hex)
    except Exception:
        # Deliberate best effort: a corrupt stored hash is just a failed check.
        return False
|
||
|
||
|
||
# ─── JWT utils ───────────────────────────────────────────────

def b64url(data: bytes) -> str:
    """Encode *data* as URL-safe base64 text with the trailing ``=`` padding removed."""
    encoded = base64.urlsafe_b64encode(data)
    return encoded.decode().rstrip("=")
|
||
|
||
|
||
def create_access_token(user_id: int, email: str, role: str) -> str:
    """Build an HS256-signed JWT access token for a user.

    Claims (in this order): ``sub`` (user id), ``email``, ``role``, ``exp``
    (Unix seconds, ACCESS_TOKEN_HOURS from now) and ``type`` = ``"access"``.
    The signature is HMAC-SHA256 over ``header.payload`` keyed by JWT_SECRET.
    """
    header = b64url(json.dumps({"alg": "HS256", "typ": "JWT"}, separators=(",", ":")).encode())
    # Aware UTC datetime: calling .timestamp() on the naive datetime.utcnow()
    # interprets it as *local* time, skewing exp by the host's UTC offset.
    expires = datetime.now(timezone.utc) + timedelta(hours=ACCESS_TOKEN_HOURS)
    exp = int(expires.timestamp())
    payload = b64url(json.dumps({
        "sub": user_id, "email": email, "role": role,
        "exp": exp, "type": "access",
    }, separators=(",", ":")).encode())
    sign_input = f"{header}.{payload}".encode()
    signature = hmac.new(JWT_SECRET.encode(), sign_input, hashlib.sha256).digest()
    return f"{header}.{payload}.{b64url(signature)}"
|
||
|
||
|
||
def create_refresh_token(user_id: int) -> str:
    """Issue an opaque refresh token for *user_id* and store it.

    The token is ``secrets.token_urlsafe(48)``. It is inserted into
    refresh_tokens with ``expires_at`` set to a naive-UTC ISO-8601 string
    REFRESH_TOKEN_DAYS from now.
    """
    expiry = datetime.utcnow() + timedelta(days=REFRESH_TOKEN_DAYS)
    new_token = secrets.token_urlsafe(48)
    _execute(
        "INSERT INTO refresh_tokens (user_id, token, expires_at) VALUES (%s, %s, %s)",
        (user_id, new_token, expiry.isoformat()),
    )
    return new_token
|
||
|
||
|
||
def parse_token(token: str) -> Optional[dict]:
    """Verify an HS256 JWT signed with JWT_SECRET and return its claims dict.

    Returns ``None`` (never raises) when:
    - the token is malformed;
    - the signature does not match;
    - the payload is not a JSON object;
    - ``exp`` (Unix seconds) is at or before the current time.

    The header segment is covered by the signature, but its ``alg`` field is
    not consulted: HMAC-SHA256 is always used.
    """
    try:
        header_b64, payload_b64, sig_b64 = token.split(".")
        sign_input = f"{header_b64}.{payload_b64}".encode()
        expected = b64url(hmac.new(JWT_SECRET.encode(), sign_input, hashlib.sha256).digest())
        if not hmac.compare_digest(expected, sig_b64):
            return None
        pad = "=" * (-len(payload_b64) % 4)
        payload = json.loads(base64.urlsafe_b64decode(payload_b64 + pad))
        if not isinstance(payload, dict):
            return None
        # Aware UTC "now": naive utcnow().timestamp() is interpreted as local time.
        now = int(datetime.now(timezone.utc).timestamp())
        # RFC 7519 §4.1.4: the token must not be accepted on or after exp.
        if int(payload.get("exp", 0)) <= now:
            return None
        return payload
    except Exception:
        # Any decoding/parsing problem simply means "not a valid token".
        return None
|
||
|
||
|
||
# ─── Auth dependency ─────────────────────────────────────────

def get_current_user(authorization: Optional[str] = Header(default=None)) -> dict:
    """FastAPI dependency resolving ``Authorization: Bearer <jwt>`` to a users row.

    Raises HTTPException:
    - 401 when the header is missing or malformed, the token is invalid or
      expired, the token is not an access token, or the user does not exist;
    - 403 when the user is banned.
    """
    prefix = "Bearer "
    if authorization is None or not authorization.startswith(prefix):
        raise HTTPException(status_code=401, detail="missing token")
    claims = parse_token(authorization[len(prefix):].strip())
    if not claims:
        raise HTTPException(status_code=401, detail="invalid or expired token")
    if claims.get("type") != "access":
        raise HTTPException(status_code=401, detail="invalid token type")
    account = _fetchone("SELECT * FROM users WHERE id = %s", (claims["sub"],))
    if not account:
        raise HTTPException(status_code=401, detail="user not found")
    if account["banned"]:
        raise HTTPException(status_code=403, detail="account banned")
    return account
|
||
|
||
|
||
def require_admin(user: dict = Depends(get_current_user)) -> dict:
    """FastAPI dependency: return the current user only if their role is ``admin`` (else 403)."""
    is_admin = user.get("role") == "admin"
    if is_admin:
        return user
    raise HTTPException(status_code=403, detail="admin required")
|
||
|
||
|
||
# ─── Request models ──────────────────────────────────────────

# Request body of POST /api/auth/register.
class RegisterReq(BaseModel):
    email: EmailStr     # lower-cased by the handler before storage
    password: str       # hashed with scrypt; never stored in clear
    invite_code: str    # looked up verbatim in invite_codes.code
|
||
|
||
|
||
# Request body of POST /api/auth/login.
class LoginReq(BaseModel):
    email: EmailStr     # matched case-insensitively (handler lower-cases it)
    password: str       # checked against users.password_hash
|
||
|
||
|
||
# Request body of POST /api/auth/refresh.
class RefreshReq(BaseModel):
    refresh_token: str  # opaque token previously issued by a login/register/refresh
|
||
|
||
|
||
# Request body of POST /api/admin/invite-codes.
class GenInviteReq(BaseModel):
    count: int = 1      # how many codes to generate
    max_uses: int = 1   # registrations allowed per code
|
||
|
||
|
||
# Request body of PUT /api/admin/users/{user_id}/ban; banned=False lifts a ban.
class BanUserReq(BaseModel):
    banned: bool = True
|
||
|
||
|
||
# ─── Auth routes ─────────────────────────────────────────────

@router.post("/auth/register")
def register(body: RegisterReq):
    """Create an account from a valid invite code and return a token pair.

    Raises HTTPException:
    - 400 for an unknown, disabled, exhausted or expired invite code;
    - 400 for an email that is already registered;
    - 500 for any other database failure (details are logged, not returned).

    The invite use is claimed atomically in the same transaction as the
    user insert, so concurrent registrations cannot exceed ``max_uses``.
    """
    ensure_tables()
    invite = _fetchone("SELECT * FROM invite_codes WHERE code = %s", (body.invite_code,))
    if not invite:
        raise HTTPException(status_code=400, detail="invalid invite code")
    if invite["status"] != "active":
        raise HTTPException(status_code=400, detail="invite code disabled")
    if invite["used_count"] >= invite["max_uses"]:
        raise HTTPException(status_code=400, detail="invite code exhausted")
    if invite["expires_at"] and invite["expires_at"] < datetime.utcnow().isoformat():
        raise HTTPException(status_code=400, detail="invite code expired")

    pwd_hash = hash_password(body.password)
    user_id = None
    try:
        with get_sync_conn() as conn:
            try:
                with conn.cursor() as cur:
                    # The checks above read a snapshot, so two concurrent
                    # requests could both pass them. This guarded UPDATE
                    # row-locks the code and only succeeds while it is still
                    # active and under max_uses. (SET expressions see the
                    # pre-update used_count.)
                    cur.execute(
                        "UPDATE invite_codes SET used_count = used_count + 1, "
                        "status = CASE WHEN used_count + 1 >= max_uses "
                        "THEN 'exhausted' ELSE status END "
                        "WHERE id = %s AND status = 'active' AND used_count < max_uses "
                        "RETURNING id",
                        (invite["id"],),
                    )
                    if cur.fetchone() is not None:
                        cur.execute(
                            "INSERT INTO users (email, password_hash, role, banned, created_at) VALUES (%s, %s, 'user', 0, %s) RETURNING id",
                            (body.email.lower(), pwd_hash, datetime.utcnow().isoformat()),
                        )
                        user_id = cur.fetchone()[0]
                        cur.execute(
                            "INSERT INTO subscriptions (user_id, tier, expires_at) VALUES (%s, 'free', NULL) ON CONFLICT(user_id) DO NOTHING",
                            (user_id,),
                        )
                        cur.execute(
                            "INSERT INTO invite_usage (code_id, user_id) VALUES (%s, %s)",
                            (invite["id"], user_id),
                        )
                if user_id is None:
                    conn.rollback()
                else:
                    conn.commit()
            except Exception:
                # Undo any partial claim/insert and leave the connection clean.
                conn.rollback()
                raise
    except Exception as e:
        msg = str(e).lower()
        if "unique" in msg or "duplicate" in msg:
            raise HTTPException(status_code=400, detail="email already registered") from e
        # Do not leak driver/SQL error text to the client.
        logging.getLogger(__name__).exception("registration failed")
        raise HTTPException(status_code=500, detail="registration failed") from e

    if user_id is None:
        # Lost the race for the last use, or the code was disabled meanwhile.
        raise HTTPException(status_code=400, detail="invite code exhausted")

    user = _fetchone("SELECT * FROM users WHERE id = %s", (user_id,))
    access = create_access_token(user["id"], user["email"], user["role"])
    refresh = create_refresh_token(user["id"])
    return {
        "access_token": access,
        "refresh_token": refresh,
        "token_type": "bearer",
        "expires_in": ACCESS_TOKEN_HOURS * 3600,
        "user": {"id": user["id"], "email": user["email"], "role": user["role"]},
    }
|
||
|
||
|
||
# ─── Login throttling ────────────────────────────────────────
# Brute-force guard for /auth/login (the review flagged the missing rate limit).
# Failed attempts are counted per lower-cased email in this process's memory.
# Once _LOGIN_MAX_FAILURES failures fall inside the last _LOGIN_WINDOW_SECONDS,
# further attempts for that email get HTTP 429 until the oldest failure ages
# out. The check runs before the costly scrypt verification. Keying on the
# email rather than the client IP keeps the guard meaningful behind a reverse
# proxy. The state is per process: with several workers the effective budget
# multiplies, and a shared store (e.g. Redis) is needed for a global limit.
_LOGIN_WINDOW_SECONDS = 60
_LOGIN_MAX_FAILURES = 10
# Soft cap on tracked emails; above it, stale entries are swept on each failure.
_LOGIN_TRACKED_KEYS_SOFT_LIMIT = 10_000
_login_failures = {}  # email -> deque of time.monotonic() failure timestamps
# Sync FastAPI endpoints run in a thread pool, so the shared dict needs a lock.
_login_failures_lock = threading.Lock()


def _prune_login_failures(key, now):
    """Drop *key*'s failures older than the window; return the rest or None.

    The caller must hold _login_failures_lock.
    """
    attempts = _login_failures.get(key)
    if attempts is None:
        return None
    cutoff = now - _LOGIN_WINDOW_SECONDS
    while attempts and attempts[0] <= cutoff:
        attempts.popleft()
    if not attempts:
        del _login_failures[key]
        return None
    return attempts


def _login_throttled(key):
    """Return True if *key* already has the maximum failures inside the window."""
    with _login_failures_lock:
        attempts = _prune_login_failures(key, time.monotonic())
        return attempts is not None and len(attempts) >= _LOGIN_MAX_FAILURES


def _record_login_failure(key):
    """Record one failed login for *key* at the current time."""
    now = time.monotonic()
    with _login_failures_lock:
        if len(_login_failures) >= _LOGIN_TRACKED_KEYS_SOFT_LIMIT:
            for stale_key in list(_login_failures):
                _prune_login_failures(stale_key, now)
        attempts = _login_failures.get(key)
        if attempts is None:
            # maxlen bounds per-email memory; only the newest failures matter.
            attempts = _login_failures[key] = deque(maxlen=_LOGIN_MAX_FAILURES)
        attempts.append(now)


def _clear_login_failures(key):
    """Forget recorded failures for *key* after a correct password."""
    with _login_failures_lock:
        _login_failures.pop(key, None)


@router.post("/auth/login")
def login(body: LoginReq):
    """Authenticate with email + password and return a token pair.

    Raises HTTPException:
    - 429 when the email has too many recent failed attempts;
    - 401 for an unknown email or a wrong password;
    - 403 for a banned account.
    """
    ensure_tables()
    email = body.email.lower()
    if _login_throttled(email):
        raise HTTPException(
            status_code=429,
            detail="too many login attempts, try again later",
            headers={"Retry-After": str(_LOGIN_WINDOW_SECONDS)},
        )
    user = _fetchone("SELECT * FROM users WHERE email = %s", (email,))
    if not user or not verify_password(body.password, user["password_hash"]):
        _record_login_failure(email)
        raise HTTPException(status_code=401, detail="invalid credentials")
    _clear_login_failures(email)
    if user["banned"]:
        raise HTTPException(status_code=403, detail="account banned")
    access = create_access_token(user["id"], user["email"], user["role"])
    refresh = create_refresh_token(user["id"])
    return {
        "access_token": access,
        "refresh_token": refresh,
        "token_type": "bearer",
        "expires_in": ACCESS_TOKEN_HOURS * 3600,
        "user": {"id": user["id"], "email": user["email"], "role": user["role"]},
    }
|
||
|
||
|
||
@router.post("/auth/refresh")
def refresh_token(body: RefreshReq):
    """Exchange a refresh token for a new access token and a new refresh token.

    The presented refresh token is single-use: it is revoked as part of the
    exchange.

    Raises HTTPException:
    - 401 if the token is unknown, already used, or expired;
    - 403 if the user no longer exists or is banned.
    """
    # Consume the token atomically. A SELECT (revoked = 0) followed by a
    # separate UPDATE left a window in which two concurrent requests could
    # both pass the check and each mint a new token pair. With a guarded
    # UPDATE ... RETURNING, exactly one request gets the row back.
    with get_sync_conn() as conn:
        try:
            with conn.cursor() as cur:
                cur.execute(
                    "UPDATE refresh_tokens SET revoked = 1 "
                    "WHERE token = %s AND revoked = 0 "
                    "RETURNING user_id, expires_at",
                    (body.refresh_token,),
                )
                row = cur.fetchone()
            conn.commit()
        except Exception:
            conn.rollback()
            raise
    if row is None:
        raise HTTPException(status_code=401, detail="invalid refresh token")
    user_id, expires_at = row
    # expires_at is stored as a naive-UTC ISO-8601 string and compared
    # lexicographically, as elsewhere in this module.
    if expires_at < datetime.utcnow().isoformat():
        raise HTTPException(status_code=401, detail="refresh token expired")
    user = _fetchone("SELECT * FROM users WHERE id = %s", (user_id,))
    if not user or user["banned"]:
        raise HTTPException(status_code=403, detail="account unavailable")
    access = create_access_token(user["id"], user["email"], user["role"])
    new_refresh = create_refresh_token(user["id"])
    return {
        "access_token": access,
        "refresh_token": new_refresh,
        "token_type": "bearer",
        "expires_in": ACCESS_TOKEN_HOURS * 3600,
    }
|
||
|
||
|
||
@router.get("/auth/me")
def me(user: dict = Depends(get_current_user)):
    """Return the caller's profile and subscription (defaults to the free tier)."""
    subscription = _fetchone(
        "SELECT tier, expires_at FROM subscriptions WHERE user_id = %s", (user["id"],)
    )
    if subscription:
        subscription_info = dict(subscription)
    else:
        subscription_info = {"tier": "free", "expires_at": None}
    profile = {
        "id": user["id"],
        "email": user["email"],
        "role": user["role"],
        "discord_id": user.get("discord_id"),
        "created_at": user["created_at"],
    }
    profile["subscription"] = subscription_info
    return profile
|
||
|
||
|
||
# ─── Admin routes ────────────────────────────────────────────

@router.post("/admin/invite-codes")
def gen_invite_codes(body: GenInviteReq, admin: dict = Depends(require_admin)):
    """Admin: create ``body.count`` invite codes, each allowing ``body.max_uses`` uses.

    Each code is the upper-cased ``secrets.token_urlsafe(6)`` string truncated
    to 8 characters. All inserts are committed together at the end, and
    ``created_by`` records the admin's user id.
    """
    insert_sql = "INSERT INTO invite_codes (code, created_by, max_uses) VALUES (%s, %s, %s)"
    issued = []
    with get_sync_conn() as conn:
        with conn.cursor() as cur:
            for _ in range(body.count):
                new_code = secrets.token_urlsafe(6)[:8].upper()
                cur.execute(insert_sql, (new_code, admin["id"], body.max_uses))
                issued.append(new_code)
        conn.commit()
    return {"codes": issued}
|
||
|
||
|
||
@router.get("/admin/invite-codes")
def list_invite_codes(admin: dict = Depends(require_admin)):
    """Admin: list every invite code, highest id first."""
    query = (
        "SELECT id, code, max_uses, used_count, status, expires_at, created_at "
        "FROM invite_codes ORDER BY id DESC"
    )
    return {"items": _fetchall(query)}
|
||
|
||
|
||
@router.delete("/admin/invite-codes/{code_id}")
def disable_invite_code(code_id: int, admin: dict = Depends(require_admin)):
    """Admin: set the invite code's status to 'disabled'.

    Always answers ``{"ok": True}``, even if no row has that id.
    """
    sql = "UPDATE invite_codes SET status = 'disabled' WHERE id = %s"
    _execute(sql, (code_id,))
    return {"ok": True}
|
||
|
||
|
||
@router.get("/admin/users")
def list_users(admin: dict = Depends(require_admin)):
    """Admin: list all users (without password hashes), highest id first."""
    query = (
        "SELECT id, email, role, banned, discord_id, created_at "
        "FROM users ORDER BY id DESC"
    )
    return {"items": _fetchall(query)}
|
||
|
||
|
||
@router.put("/admin/users/{user_id}/ban")
def ban_user(user_id: int, body: BanUserReq, admin: dict = Depends(require_admin)):
    """Admin: set or clear the banned flag (stored as 1/0) on a user.

    Responds with ``ok`` even if no user has that id.
    """
    flag = int(body.banned)
    _execute("UPDATE users SET banned = %s WHERE id = %s", (flag, user_id))
    return {"ok": True, "user_id": user_id, "banned": body.banned}
|