arbitrage-engine/backend/auth.py
dev-worker 27a51b4d19 fix: P0第二轮修复 — JWT安全/DB密码/SL紧急平仓reduceOnly/TP1状态守卫/超时精度/跨策略去重 + 硬编码消除
P0-1: JWT_SECRET生产环境强制配置,测试环境保留默认
P0-2: DB密码生产环境强制从env读,测试环境保留fallback
P0-3: SL三次失败→查真实持仓→reduceOnly平仓→校验结果→写event
P0-4: TP1后SL重挂失败则不推进tp1_hit状态,continue等下轮重试
P0-5: 超时自动平仓用SYMBOL_QTY_PRECISION格式化+校验结果
P0-6: 同币种去重改为不区分策略(币安单向模式共享净仓位)
P1-1: 手续费窗口entry_ts-200→+200(避免纳入开仓前成交)
额外: 模拟盘*200和实盘*2硬编码→从配置动态读取
2026-03-02 16:11:43 +00:00

389 lines
14 KiB
Python

import base64
import hashlib
import hmac
import json
import logging
import os
import secrets
import time
from datetime import datetime, timedelta
from typing import Optional

from fastapi import APIRouter, HTTPException, Header, Depends, Request
from pydantic import BaseModel, EmailStr

from db import get_sync_conn
# Deployment environment; anything other than "testnet" is treated as production.
_TRADE_ENV = os.getenv("TRADE_ENV", "testnet")
# Fallback signing secret, used only on testnet when JWT_SECRET is unset. It is
# public (it lives in source), so it is exempted from the length rule below
# rather than padded: padding would change the testnet signing key.
_TESTNET_JWT_SECRET = "arb-engine-jwt-secret-v2-2026"
_jwt_default = _TESTNET_JWT_SECRET if _TRADE_ENV == "testnet" else None
JWT_SECRET = os.getenv("JWT_SECRET") or _jwt_default
# The testnet default is only 29 characters. Applying the >=32 rule to it made
# the module refuse to import on testnet whenever JWT_SECRET was unset. Every
# explicitly configured secret must still be >=32 chars, and production (which
# has no default) still refuses to start without one.
if not JWT_SECRET or (JWT_SECRET != _jwt_default and len(JWT_SECRET) < 32):
    raise RuntimeError("JWT_SECRET 未配置或长度不足(>=32),生产环境必须设置环境变量")
# Lifetime of issued access tokens (hours) and refresh tokens (days).
ACCESS_TOKEN_HOURS = 24
REFRESH_TOKEN_DAYS = 7
# Every route below is mounted under the /api prefix and tagged "auth".
router = APIRouter(tags=["auth"], prefix="/api")
# ─── PG Schema ───────────────────────────────────────────────
# PostgreSQL DDL for the auth tables (BIGSERIAL, NOW()::TEXT are PG syntax).
# ensure_tables() splits this string on ";" and runs each piece separately,
# so no statement here may contain a literal ";".
# Every statement is CREATE TABLE IF NOT EXISTS: editing a column definition
# here does NOT alter an already-existing table; that needs a migration.
# Timestamps are TEXT. Python code in this module writes
# datetime.utcnow().isoformat() strings and compares them as strings, while
# the column DEFAULTs use NOW()::TEXT, which is a different text format.
# users.banned and refresh_tokens.revoked are 0/1 INTEGER flags.
# NOTE(review): invite_codes.created_by is INTEGER with no foreign key, while
# users.id is BIGSERIAL (BIGINT). Confirm whether that is intentional.
AUTH_SCHEMA = """
CREATE TABLE IF NOT EXISTS users (
id BIGSERIAL PRIMARY KEY,
email TEXT UNIQUE NOT NULL,
password_hash TEXT NOT NULL,
discord_id TEXT,
role TEXT NOT NULL DEFAULT 'user',
banned INTEGER NOT NULL DEFAULT 0,
created_at TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS subscriptions (
user_id BIGINT PRIMARY KEY REFERENCES users(id),
tier TEXT NOT NULL DEFAULT 'free',
expires_at TEXT
);
CREATE TABLE IF NOT EXISTS invite_codes (
id BIGSERIAL PRIMARY KEY,
code TEXT UNIQUE NOT NULL,
created_by INTEGER,
max_uses INTEGER NOT NULL DEFAULT 1,
used_count INTEGER NOT NULL DEFAULT 0,
status TEXT NOT NULL DEFAULT 'active',
expires_at TEXT,
created_at TEXT DEFAULT (NOW()::TEXT)
);
CREATE TABLE IF NOT EXISTS invite_usage (
id BIGSERIAL PRIMARY KEY,
code_id BIGINT NOT NULL REFERENCES invite_codes(id),
user_id BIGINT NOT NULL REFERENCES users(id),
used_at TEXT DEFAULT (NOW()::TEXT)
);
CREATE TABLE IF NOT EXISTS refresh_tokens (
id BIGSERIAL PRIMARY KEY,
user_id BIGINT NOT NULL REFERENCES users(id),
token TEXT UNIQUE NOT NULL,
expires_at TEXT NOT NULL,
revoked INTEGER NOT NULL DEFAULT 0,
created_at TEXT DEFAULT (NOW()::TEXT)
);
"""
# Set to True once every AUTH_SCHEMA statement has succeeded in this process.
_auth_tables_ready = False


def ensure_tables():
    """Create the auth tables described by AUTH_SCHEMA if they are missing.

    register() and login() call this on every request. Each statement is
    executed and committed on its own, so a failing statement is rolled back
    alone. Previously a single commit after the loop meant a rollback also
    discarded every earlier successful CREATE TABLE in the same transaction.
    Failures are logged and otherwise ignored (best effort, as before).
    After one fully successful run this function is a no-op for the rest of
    the process.
    """
    global _auth_tables_ready
    if _auth_tables_ready:
        return
    all_ok = True
    with get_sync_conn() as conn:
        with conn.cursor() as cur:
            for stmt in AUTH_SCHEMA.split(";"):
                stmt = stmt.strip()
                if not stmt:
                    continue
                try:
                    cur.execute(stmt)
                    conn.commit()
                except Exception:
                    conn.rollback()
                    all_ok = False
                    logging.getLogger(__name__).warning(
                        "ensure_tables: statement failed: %s",
                        stmt.splitlines()[0],
                        exc_info=True,
                    )
    # Only cache success; a failed run is retried on the next call.
    _auth_tables_ready = all_ok
# ─── DB helper ───────────────────────────────────────────────
def _fetchone(sql, params=None):
    """Run *sql* with *params* and return the first row as a dict.

    Keys are the result column names. Returns None when there is no row.
    """
    with get_sync_conn() as conn, conn.cursor() as cur:
        cur.execute(sql, params or ())
        first = cur.fetchone()
        if not first:
            return None
        return {desc[0]: value for desc, value in zip(cur.description, first)}
def _fetchall(sql, params=None):
    """Run *sql* with *params* and return every row as a column-name -> value dict."""
    with get_sync_conn() as conn, conn.cursor() as cur:
        cur.execute(sql, params or ())
        names = [desc[0] for desc in cur.description]
        rows = cur.fetchall()
        return [dict(zip(names, row)) for row in rows]
def _execute(sql, params=None):
    """Execute *sql* with *params*, commit, and return its first result row.

    Returns the row (e.g. one produced by RETURNING) when one can be fetched.
    Any exception from fetchone() after the commit, such as the statement
    having no result set, is swallowed and reported as None.
    """
    with get_sync_conn() as conn, conn.cursor() as cur:
        cur.execute(sql, params or ())
        conn.commit()
        result = None
        try:
            result = cur.fetchone()
        except Exception:
            result = None
        return result
# ─── Password utils ──────────────────────────────────────────
def hash_password(password: str) -> str:
    """Hash *password* with scrypt (N=2**14, r=8, p=1) and a fresh random salt.

    Returns "<salt>$<digest>": the salt is 32 hex characters (16 random bytes)
    and the digest is the hex-encoded scrypt output. The salt's hex text, not
    its raw bytes, is what is fed to scrypt.
    """
    salt_hex = secrets.token_hex(16)
    derived = hashlib.scrypt(
        password.encode(),
        salt=salt_hex.encode(),
        n=2**14,
        r=8,
        p=1,
    )
    return "$".join((salt_hex, derived.hex()))
def verify_password(password: str, stored: str) -> bool:
    """Check *password* against a "<salt>$<digest>" string made by hash_password.

    The digest comparison is constant-time. Malformed input, such as a missing
    "$" separator or non-string values, yields False instead of raising.
    """
    try:
        salt, sep, expected = stored.partition("$")
        if not sep:
            return False
        recomputed = hashlib.scrypt(
            password.encode(), salt=salt.encode(), n=2**14, r=8, p=1
        )
        return hmac.compare_digest(recomputed.hex(), expected)
    except Exception:
        return False
# ─── JWT utils ───────────────────────────────────────────────
def b64url(data: bytes) -> str:
    """Return the URL-safe base64 text of *data* with trailing "=" padding removed."""
    encoded = base64.urlsafe_b64encode(data).decode()
    return encoded.rstrip("=")
def create_access_token(user_id: int, email: str, role: str) -> str:
    """Return an HS256-signed JWT access token for the given user.

    Payload claims: sub (user id), email, role, exp and type="access". The
    token is signed with JWT_SECRET. exp is absolute POSIX seconds,
    ACCESS_TOKEN_HOURS from now.
    """
    header = b64url(json.dumps({"alg": "HS256", "typ": "JWT"}, separators=(",", ":")).encode())
    # time.time() is timezone-independent. The previous
    # datetime.utcnow().timestamp() interpreted the naive UTC value as *local*
    # time, which skewed exp by the host's UTC offset.
    exp = int(time.time()) + ACCESS_TOKEN_HOURS * 3600
    payload = b64url(json.dumps({
        "sub": user_id, "email": email, "role": role,
        "exp": exp, "type": "access"
    }, separators=(",", ":")).encode())
    sign_input = f"{header}.{payload}".encode()
    signature = hmac.new(JWT_SECRET.encode(), sign_input, hashlib.sha256).digest()
    return f"{header}.{payload}.{b64url(signature)}"


def create_refresh_token(user_id: int) -> str:
    """Create, persist and return a new opaque refresh token for *user_id*.

    The token is secrets.token_urlsafe(48) text and is stored in
    refresh_tokens. Its expiry is stored as a naive-UTC ISO-8601 string
    REFRESH_TOKEN_DAYS ahead. Do not change this format: refresh_token()
    compares it as a string against datetime.utcnow().isoformat().
    """
    token = secrets.token_urlsafe(48)
    expires_at = (datetime.utcnow() + timedelta(days=REFRESH_TOKEN_DAYS)).isoformat()
    _execute(
        "INSERT INTO refresh_tokens (user_id, token, expires_at) VALUES (%s, %s, %s)",
        (user_id, token, expires_at)
    )
    return token


def parse_token(token: str) -> Optional[dict]:
    """Verify a JWT produced by create_access_token and return its payload.

    Returns None, and never raises, when the token is malformed, when the
    HMAC-SHA256 signature does not match JWT_SECRET, or when exp is in the past.
    The header's "alg" is deliberately not consulted: the signature is always
    checked as HS256.
    """
    try:
        header_b64, payload_b64, sig_b64 = token.split(".")
        sign_input = f"{header_b64}.{payload_b64}".encode()
        expected = b64url(hmac.new(JWT_SECRET.encode(), sign_input, hashlib.sha256).digest())
        # Constant-time comparison so the signature is not leaked via timing.
        if not hmac.compare_digest(expected, sig_b64):
            return None
        pad = "=" * (-len(payload_b64) % 4)
        payload = json.loads(base64.urlsafe_b64decode(payload_b64 + pad))
        # Same clock as create_access_token: absolute POSIX seconds.
        if int(payload.get("exp", 0)) < int(time.time()):
            return None
        return payload
    except Exception:
        return None
# ─── Auth dependency ─────────────────────────────────────────
def get_current_user(authorization: Optional[str] = Header(default=None)) -> dict:
    """FastAPI dependency that turns an "Authorization: Bearer <jwt>" header into a user row.

    Raises:
        HTTPException 401: the header is missing or not "Bearer ...", the
            token is invalid or expired, the token is not of type "access",
            or the user no longer exists.
        HTTPException 403: the user is banned.
    """
    if not (authorization and authorization.startswith("Bearer ")):
        raise HTTPException(status_code=401, detail="missing token")
    _, _, raw_token = authorization.partition(" ")
    claims = parse_token(raw_token.strip())
    if not claims:
        raise HTTPException(status_code=401, detail="invalid or expired token")
    if claims.get("type") != "access":
        raise HTTPException(status_code=401, detail="invalid token type")
    account = _fetchone("SELECT * FROM users WHERE id = %s", (claims["sub"],))
    if not account:
        raise HTTPException(status_code=401, detail="user not found")
    if account["banned"]:
        raise HTTPException(status_code=403, detail="account banned")
    return account
def require_admin(user: dict = Depends(get_current_user)) -> dict:
    """FastAPI dependency: the authenticated user, but only if their role is "admin" (else 403)."""
    is_admin = user.get("role") == "admin"
    if is_admin:
        return user
    raise HTTPException(status_code=403, detail="admin required")
# ─── Request models ──────────────────────────────────────────
# Request body of POST /auth/register.
class RegisterReq(BaseModel):
    email: EmailStr  # register() lower-cases it before storing
    password: str  # plaintext; hashed with hash_password(). No strength rule is enforced here.
    invite_code: str  # matched verbatim against invite_codes.code
# Request body of POST /auth/login.
class LoginReq(BaseModel):
    email: EmailStr  # login() lower-cases it before the lookup
    password: str  # plaintext; checked with verify_password()
# Request body of POST /auth/refresh.
class RefreshReq(BaseModel):
    refresh_token: str  # opaque token returned earlier by register/login/refresh
# Request body of POST /admin/invite-codes.
class GenInviteReq(BaseModel):
    count: int = 1  # how many codes to generate in this request
    max_uses: int = 1  # how many registrations each generated code allows
    # NOTE(review): the model itself does not range-check either field.
# Request body of PUT /admin/users/{user_id}/ban.
class BanUserReq(BaseModel):
    banned: bool = True  # True bans the user, False lifts the ban (stored as 1/0)
# ─── Auth routes ─────────────────────────────────────────────
@router.post("/auth/register")
def register(body: RegisterReq):
    """Create an account from an invite code and return a new token pair.

    The invite is first pre-checked, only to give specific error messages.
    It is then claimed atomically inside the registration transaction with a
    conditional UPDATE. Concurrent sign-ups can therefore never push
    used_count past max_uses; previously the check and the increment were
    separate and raced.

    The invite claim, the user row, the free subscription row and the
    invite_usage record are committed together or rolled back together.

    Raises:
        HTTPException 400: the invite is unknown, disabled, exhausted or
            expired, or the email is already registered.
        HTTPException 500: any other database failure. Details are logged
            server-side and are no longer echoed to the client.
    """
    ensure_tables()
    invite = _fetchone("SELECT * FROM invite_codes WHERE code = %s", (body.invite_code,))
    if not invite:
        raise HTTPException(status_code=400, detail="invalid invite code")
    if invite["status"] != "active":
        raise HTTPException(status_code=400, detail="invite code disabled")
    if invite["used_count"] >= invite["max_uses"]:
        raise HTTPException(status_code=400, detail="invite code exhausted")
    # NOTE(review): string comparison; assumes expires_at uses the same
    # naive-UTC ISO-8601 format as datetime.utcnow().isoformat().
    if invite["expires_at"] and invite["expires_at"] < datetime.utcnow().isoformat():
        raise HTTPException(status_code=400, detail="invite code expired")
    pwd_hash = hash_password(body.password)
    try:
        with get_sync_conn() as conn:
            try:
                with conn.cursor() as cur:
                    # Atomic claim. PostgreSQL re-checks the WHERE clause after
                    # waiting on a concurrent update of the same row, so only
                    # callers that still see spare capacity get a row back. The
                    # CASE reads the pre-update used_count.
                    cur.execute(
                        "UPDATE invite_codes SET used_count = used_count + 1, "
                        "status = CASE WHEN used_count + 1 >= max_uses "
                        "THEN 'exhausted' ELSE status END "
                        "WHERE id = %s AND status = 'active' AND used_count < max_uses "
                        "RETURNING id",
                        (invite["id"],),
                    )
                    if cur.fetchone() is None:
                        raise HTTPException(status_code=400, detail="invite code exhausted")
                    cur.execute(
                        "INSERT INTO users (email, password_hash, role, banned, created_at) VALUES (%s, %s, 'user', 0, %s) RETURNING id",
                        (body.email.lower(), pwd_hash, datetime.utcnow().isoformat()),
                    )
                    user_id = cur.fetchone()[0]
                    cur.execute(
                        "INSERT INTO subscriptions (user_id, tier, expires_at) VALUES (%s, 'free', NULL) ON CONFLICT(user_id) DO NOTHING",
                        (user_id,),
                    )
                    cur.execute(
                        "INSERT INTO invite_usage (code_id, user_id) VALUES (%s, %s)",
                        (invite["id"], user_id),
                    )
                conn.commit()
            except Exception:
                # Undo the partial transaction, including the invite claim.
                conn.rollback()
                raise
    except HTTPException:
        raise
    except Exception as e:
        msg = str(e).lower()
        # users.email is the only UNIQUE column written above.
        if "unique" in msg or "duplicate" in msg:
            raise HTTPException(status_code=400, detail="email already registered") from e
        logging.getLogger(__name__).exception("register: database error")
        raise HTTPException(status_code=500, detail="registration failed") from e
    user = _fetchone("SELECT * FROM users WHERE id = %s", (user_id,))
    access = create_access_token(user["id"], user["email"], user["role"])
    refresh = create_refresh_token(user["id"])
    return {
        "access_token": access,
        "refresh_token": refresh,
        "token_type": "bearer",
        "expires_in": ACCESS_TOKEN_HOURS * 3600,
        "user": {"id": user["id"], "email": user["email"], "role": user["role"]},
    }
@router.post("/auth/login")
def login(body: LoginReq):
    """Authenticate with email and password and return a new token pair.

    Raises:
        HTTPException 401: unknown email or wrong password. Both get the same
            message and comparable scrypt cost, so neither the response nor
            its timing reveals whether the email is registered.
        HTTPException 403: the account is banned. This is checked only after
            the password is verified.
    """
    ensure_tables()
    user = _fetchone("SELECT * FROM users WHERE email = %s", (body.email.lower(),))
    if not user:
        # Spend the same scrypt work as verify_password() would, so an unknown
        # email does not answer measurably faster than a wrong password.
        hash_password(body.password)
        raise HTTPException(status_code=401, detail="invalid credentials")
    if not verify_password(body.password, user["password_hash"]):
        raise HTTPException(status_code=401, detail="invalid credentials")
    if user["banned"]:
        raise HTTPException(status_code=403, detail="account banned")
    access = create_access_token(user["id"], user["email"], user["role"])
    refresh = create_refresh_token(user["id"])
    return {
        "access_token": access,
        "refresh_token": refresh,
        "token_type": "bearer",
        "expires_in": ACCESS_TOKEN_HOURS * 3600,
        "user": {"id": user["id"], "email": user["email"], "role": user["role"]},
    }
@router.post("/auth/refresh")
def refresh_token(body: RefreshReq):
    """Exchange a valid refresh token for a new access/refresh token pair.

    The presented refresh token is single-use: it is revoked here and a new
    one is issued. The revoke is conditional (revoked = 0 in its WHERE
    clause), so if two requests race with the same token only one succeeds.
    Previously both passed the SELECT and both received fresh tokens.

    Raises:
        HTTPException 401: the token is unknown, already revoked (including
            by a concurrent request), or expired.
        HTTPException 403: the owning user no longer exists or is banned.
    """
    row = _fetchone(
        "SELECT * FROM refresh_tokens WHERE token = %s AND revoked = 0", (body.refresh_token,)
    )
    if not row:
        raise HTTPException(status_code=401, detail="invalid refresh token")
    # expires_at is written by create_refresh_token() as a naive-UTC ISO string.
    if row["expires_at"] < datetime.utcnow().isoformat():
        raise HTTPException(status_code=401, detail="refresh token expired")
    user = _fetchone("SELECT * FROM users WHERE id = %s", (row["user_id"],))
    if not user or user["banned"]:
        raise HTTPException(status_code=403, detail="account unavailable")
    with get_sync_conn() as conn:
        with conn.cursor() as cur:
            cur.execute(
                "UPDATE refresh_tokens SET revoked = 1 WHERE id = %s AND revoked = 0 RETURNING id",
                (row["id"],),
            )
            claimed = cur.fetchone() is not None
        conn.commit()
    if not claimed:
        # Another request revoked this token between our SELECT and UPDATE.
        raise HTTPException(status_code=401, detail="invalid refresh token")
    access = create_access_token(user["id"], user["email"], user["role"])
    new_refresh = create_refresh_token(user["id"])
    return {
        "access_token": access,
        "refresh_token": new_refresh,
        "token_type": "bearer",
        "expires_in": ACCESS_TOKEN_HOURS * 3600,
    }
@router.get("/auth/me")
def me(user: dict = Depends(get_current_user)):
    """Return the authenticated user's profile together with their subscription.

    A user with no subscriptions row is reported as tier "free" with no expiry.
    """
    sub_row = _fetchone(
        "SELECT tier, expires_at FROM subscriptions WHERE user_id = %s",
        (user["id"],),
    )
    if sub_row:
        subscription = dict(sub_row)
    else:
        subscription = {"tier": "free", "expires_at": None}
    profile = {
        "id": user["id"],
        "email": user["email"],
        "role": user["role"],
        "discord_id": user.get("discord_id"),
        "created_at": user["created_at"],
    }
    profile["subscription"] = subscription
    return profile
# ─── Admin routes ────────────────────────────────────────────
@router.post("/admin/invite-codes")
def gen_invite_codes(body: GenInviteReq, admin: dict = Depends(require_admin)):
    """Admin only: create body.count invite codes, each usable body.max_uses times.

    Each code is secrets.token_urlsafe(6) (8 characters), upper-cased. A code
    that collides with an existing one (invite_codes.code is UNIQUE) is
    skipped and regenerated. Previously one collision aborted the whole batch
    with a 500. A count <= 0 yields an empty list, as before.

    Raises:
        HTTPException 400: max_uses is below 1. Such codes could never be
            redeemed, because register() treats used_count >= max_uses as
            exhausted.
    """
    if body.max_uses < 1:
        raise HTTPException(status_code=400, detail="max_uses must be >= 1")
    codes = []
    with get_sync_conn() as conn:
        with conn.cursor() as cur:
            # NOTE(review): count has no upper bound; consider capping batch size.
            while len(codes) < body.count:
                code = secrets.token_urlsafe(6)[:8].upper()
                cur.execute(
                    "INSERT INTO invite_codes (code, created_by, max_uses) VALUES (%s, %s, %s) "
                    "ON CONFLICT (code) DO NOTHING RETURNING id",
                    (code, admin["id"], body.max_uses),
                )
                # No row back means the code already existed; try another.
                if cur.fetchone() is not None:
                    codes.append(code)
        conn.commit()
    return {"codes": codes}
@router.get("/admin/invite-codes")
def list_invite_codes(admin: dict = Depends(require_admin)):
    """Admin only: list every invite code with its usage counters, newest id first."""
    query = (
        "SELECT id, code, max_uses, used_count, status, expires_at, created_at "
        "FROM invite_codes ORDER BY id DESC"
    )
    return {"items": _fetchall(query)}
@router.delete("/admin/invite-codes/{code_id}")
def disable_invite_code(code_id: int, admin: dict = Depends(require_admin)):
    """Admin only: mark invite code *code_id* as disabled (the row is kept, not deleted).

    Always answers {"ok": True}; it does not check that the id exists.
    """
    params = (code_id,)
    _execute("UPDATE invite_codes SET status = 'disabled' WHERE id = %s", params)
    return dict(ok=True)
@router.get("/admin/users")
def list_users(admin: dict = Depends(require_admin)):
    """Admin only: list all users (no password hashes), newest id first."""
    query = (
        "SELECT id, email, role, banned, discord_id, created_at "
        "FROM users ORDER BY id DESC"
    )
    users = _fetchall(query)
    return {"items": users}
@router.put("/admin/users/{user_id}/ban")
def ban_user(user_id: int, body: BanUserReq, admin: dict = Depends(require_admin)):
    """Admin only: set (or clear) the banned flag of *user_id*, stored as 1/0.

    It does not check that the user exists; the response echoes the request.
    """
    banned_flag = int(body.banned)
    _execute("UPDATE users SET banned = %s WHERE id = %s", (banned_flag, user_id))
    return {"ok": True, "user_id": user_id, "banned": body.banned}