feat: V2.0 auth system - JWT access/refresh, invite codes, route protection, admin CLI, auth gate blur overlay
This commit is contained in:
parent
052e5a0541
commit
1ab228286c
115
backend/admin_cli.py
Normal file
115
backend/admin_cli.py
Normal file
@ -0,0 +1,115 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Admin CLI for Arbitrage Engine"""
|
||||
import sys, os, sqlite3, secrets, json
|
||||
|
||||
DB_PATH = os.path.join(os.path.dirname(__file__), "..", "arb.db")
|
||||
|
||||
def get_conn():
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
conn.row_factory = sqlite3.Row
|
||||
return conn
|
||||
|
||||
def gen_invite(count=1, max_uses=1):
|
||||
conn = get_conn()
|
||||
codes = []
|
||||
for _ in range(count):
|
||||
code = secrets.token_urlsafe(6)[:8].upper()
|
||||
conn.execute(
|
||||
"INSERT INTO invite_codes (code, created_by, max_uses) VALUES (?, 1, ?)",
|
||||
(code, max_uses)
|
||||
)
|
||||
codes.append(code)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
for c in codes:
|
||||
print(f" {c}")
|
||||
print(f"\nGenerated {len(codes)} invite code(s)")
|
||||
|
||||
def list_invites():
|
||||
conn = get_conn()
|
||||
rows = conn.execute("SELECT id, code, max_uses, used_count, status, created_at FROM invite_codes ORDER BY id DESC").fetchall()
|
||||
conn.close()
|
||||
if not rows:
|
||||
print("No invite codes found")
|
||||
return
|
||||
print(f"{'ID':>4} {'CODE':>10} {'MAX':>4} {'USED':>5} {'STATUS':>10} {'CREATED':>20}")
|
||||
print("-" * 60)
|
||||
for r in rows:
|
||||
print(f"{r['id']:>4} {r['code']:>10} {r['max_uses']:>4} {r['used_count']:>5} {r['status']:>10} {r['created_at']:>20}")
|
||||
|
||||
def disable_invite(code):
|
||||
conn = get_conn()
|
||||
conn.execute("UPDATE invite_codes SET status = 'disabled' WHERE code = ?", (code,))
|
||||
conn.commit()
|
||||
conn.close()
|
||||
print(f"Disabled invite code: {code}")
|
||||
|
||||
def list_users():
|
||||
conn = get_conn()
|
||||
rows = conn.execute("SELECT id, email, role, banned, created_at FROM users ORDER BY id DESC").fetchall()
|
||||
conn.close()
|
||||
if not rows:
|
||||
print("No users found")
|
||||
return
|
||||
print(f"{'ID':>4} {'EMAIL':>30} {'ROLE':>6} {'BANNED':>7} {'CREATED':>20}")
|
||||
print("-" * 72)
|
||||
for r in rows:
|
||||
print(f"{r['id']:>4} {r['email']:>30} {r['role']:>6} {r['banned']:>7} {r['created_at']:>20}")
|
||||
|
||||
def ban_user(user_id):
|
||||
conn = get_conn()
|
||||
conn.execute("UPDATE users SET banned = 1 WHERE id = ?", (user_id,))
|
||||
conn.commit()
|
||||
conn.close()
|
||||
print(f"Banned user ID: {user_id}")
|
||||
|
||||
def unban_user(user_id):
|
||||
conn = get_conn()
|
||||
conn.execute("UPDATE users SET banned = 0 WHERE id = ?", (user_id,))
|
||||
conn.commit()
|
||||
conn.close()
|
||||
print(f"Unbanned user ID: {user_id}")
|
||||
|
||||
def set_admin(user_id):
|
||||
conn = get_conn()
|
||||
conn.execute("UPDATE users SET role = 'admin' WHERE id = ?", (user_id,))
|
||||
conn.commit()
|
||||
conn.close()
|
||||
print(f"Set user {user_id} as admin")
|
||||
|
||||
def usage():
|
||||
print("""Usage: python3 admin_cli.py <command> [args]
|
||||
|
||||
Commands:
|
||||
gen-invite [count] [max_uses] Generate invite codes (default: 1 code, 1 use)
|
||||
list-invites List all invite codes
|
||||
disable-invite <CODE> Disable an invite code
|
||||
list-users List all users
|
||||
ban-user <USER_ID> Ban a user
|
||||
unban-user <USER_ID> Unban a user
|
||||
set-admin <USER_ID> Set user as admin
|
||||
""")
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) < 2:
|
||||
usage()
|
||||
sys.exit(1)
|
||||
cmd = sys.argv[1]
|
||||
if cmd == "gen-invite":
|
||||
count = int(sys.argv[2]) if len(sys.argv) > 2 else 1
|
||||
max_uses = int(sys.argv[3]) if len(sys.argv) > 3 else 1
|
||||
gen_invite(count, max_uses)
|
||||
elif cmd == "list-invites":
|
||||
list_invites()
|
||||
elif cmd == "disable-invite":
|
||||
disable_invite(sys.argv[2])
|
||||
elif cmd == "list-users":
|
||||
list_users()
|
||||
elif cmd == "ban-user":
|
||||
ban_user(int(sys.argv[2]))
|
||||
elif cmd == "unban-user":
|
||||
unban_user(int(sys.argv[2]))
|
||||
elif cmd == "set-admin":
|
||||
set_admin(int(sys.argv[2]))
|
||||
else:
|
||||
usage()
|
||||
297
backend/auth.py
297
backend/auth.py
@ -7,17 +7,21 @@ import base64
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
from functools import wraps
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Header
|
||||
from fastapi import APIRouter, HTTPException, Header, Depends, Request
|
||||
from pydantic import BaseModel, EmailStr
|
||||
|
||||
DB_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "arb.db")
|
||||
JWT_SECRET = os.getenv("JWT_SECRET", "arb-secret-change-me")
|
||||
JWT_EXPIRE_HOURS = 24 * 7
|
||||
JWT_SECRET = os.getenv("JWT_SECRET", "arb-engine-jwt-secret-v2-2026")
|
||||
ACCESS_TOKEN_HOURS = 24
|
||||
REFRESH_TOKEN_DAYS = 7
|
||||
|
||||
router = APIRouter(prefix="/api", tags=["auth"])
|
||||
|
||||
|
||||
# ─── DB helpers ───────────────────────────────────────────────
|
||||
|
||||
def get_conn():
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
conn.row_factory = sqlite3.Row
|
||||
@ -27,29 +31,26 @@ def get_conn():
|
||||
def ensure_tables():
|
||||
conn = get_conn()
|
||||
cur = conn.cursor()
|
||||
cur.execute(
|
||||
"""
|
||||
cur.execute("""
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
email TEXT UNIQUE NOT NULL,
|
||||
password_hash TEXT NOT NULL,
|
||||
discord_id TEXT,
|
||||
role TEXT NOT NULL DEFAULT 'user',
|
||||
banned INTEGER NOT NULL DEFAULT 0,
|
||||
created_at TEXT NOT NULL
|
||||
)
|
||||
"""
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
""")
|
||||
cur.execute("""
|
||||
CREATE TABLE IF NOT EXISTS subscriptions (
|
||||
user_id INTEGER PRIMARY KEY,
|
||||
tier TEXT NOT NULL DEFAULT 'free',
|
||||
expires_at TEXT,
|
||||
FOREIGN KEY(user_id) REFERENCES users(id)
|
||||
)
|
||||
"""
|
||||
)
|
||||
cur.execute(
|
||||
"""
|
||||
""")
|
||||
cur.execute("""
|
||||
CREATE TABLE IF NOT EXISTS signal_logs (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
symbol TEXT NOT NULL,
|
||||
@ -58,12 +59,55 @@ def ensure_tables():
|
||||
sent_at TEXT NOT NULL,
|
||||
message TEXT NOT NULL
|
||||
)
|
||||
"""
|
||||
""")
|
||||
cur.execute("""
|
||||
CREATE TABLE IF NOT EXISTS invite_codes (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
code TEXT UNIQUE NOT NULL,
|
||||
created_by INTEGER,
|
||||
max_uses INTEGER NOT NULL DEFAULT 1,
|
||||
used_count INTEGER NOT NULL DEFAULT 0,
|
||||
status TEXT NOT NULL DEFAULT 'active',
|
||||
expires_at TEXT,
|
||||
created_at TEXT DEFAULT (datetime('now'))
|
||||
)
|
||||
""")
|
||||
cur.execute("""
|
||||
CREATE TABLE IF NOT EXISTS invite_usage (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
code_id INTEGER NOT NULL,
|
||||
user_id INTEGER NOT NULL,
|
||||
used_at TEXT DEFAULT (datetime('now')),
|
||||
FOREIGN KEY (code_id) REFERENCES invite_codes(id),
|
||||
FOREIGN KEY (user_id) REFERENCES users(id)
|
||||
)
|
||||
""")
|
||||
cur.execute("""
|
||||
CREATE TABLE IF NOT EXISTS refresh_tokens (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id INTEGER NOT NULL,
|
||||
token TEXT UNIQUE NOT NULL,
|
||||
expires_at TEXT NOT NULL,
|
||||
revoked INTEGER NOT NULL DEFAULT 0,
|
||||
created_at TEXT DEFAULT (datetime('now')),
|
||||
FOREIGN KEY (user_id) REFERENCES users(id)
|
||||
)
|
||||
""")
|
||||
# migrate: add role/banned columns if missing
|
||||
try:
|
||||
cur.execute("ALTER TABLE users ADD COLUMN role TEXT NOT NULL DEFAULT 'user'")
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
cur.execute("ALTER TABLE users ADD COLUMN banned INTEGER NOT NULL DEFAULT 0")
|
||||
except:
|
||||
pass
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
|
||||
# ─── Password utils ──────────────────────────────────────────
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
salt = secrets.token_hex(16)
|
||||
digest = hashlib.scrypt(password.encode(), salt=salt.encode(), n=2**14, r=8, p=1).hex()
|
||||
@ -79,19 +123,37 @@ def verify_password(password: str, stored: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
# ─── JWT utils ───────────────────────────────────────────────
|
||||
|
||||
def b64url(data: bytes) -> str:
|
||||
return base64.urlsafe_b64encode(data).rstrip(b"=").decode()
|
||||
|
||||
|
||||
def create_token(user_id: int, email: str) -> str:
|
||||
def create_access_token(user_id: int, email: str, role: str) -> str:
|
||||
header = b64url(json.dumps({"alg": "HS256", "typ": "JWT"}, separators=(",", ":")).encode())
|
||||
exp = int((datetime.utcnow() + timedelta(hours=JWT_EXPIRE_HOURS)).timestamp())
|
||||
payload = b64url(json.dumps({"sub": user_id, "email": email, "exp": exp}, separators=(",", ":")).encode())
|
||||
exp = int((datetime.utcnow() + timedelta(hours=ACCESS_TOKEN_HOURS)).timestamp())
|
||||
payload = b64url(json.dumps({
|
||||
"sub": user_id, "email": email, "role": role,
|
||||
"exp": exp, "type": "access"
|
||||
}, separators=(",", ":")).encode())
|
||||
sign_input = f"{header}.{payload}".encode()
|
||||
signature = hmac.new(JWT_SECRET.encode(), sign_input, hashlib.sha256).digest()
|
||||
return f"{header}.{payload}.{b64url(signature)}"
|
||||
|
||||
|
||||
def create_refresh_token(user_id: int) -> str:
|
||||
token = secrets.token_urlsafe(48)
|
||||
expires_at = (datetime.utcnow() + timedelta(days=REFRESH_TOKEN_DAYS)).isoformat()
|
||||
conn = get_conn()
|
||||
conn.execute(
|
||||
"INSERT INTO refresh_tokens (user_id, token, expires_at) VALUES (?, ?, ?)",
|
||||
(user_id, token, expires_at)
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
return token
|
||||
|
||||
|
||||
def parse_token(token: str) -> Optional[dict]:
|
||||
try:
|
||||
header_b64, payload_b64, sig_b64 = token.split(".")
|
||||
@ -108,24 +170,41 @@ def parse_token(token: str) -> Optional[dict]:
|
||||
return None
|
||||
|
||||
|
||||
def get_user_from_auth(authorization: Optional[str]) -> sqlite3.Row:
|
||||
# ─── Auth dependency ─────────────────────────────────────────
|
||||
|
||||
def get_current_user(authorization: Optional[str] = Header(default=None)) -> dict:
|
||||
"""FastAPI dependency: returns user dict or raises 401"""
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="missing token")
|
||||
token = authorization.split(" ", 1)[1].strip()
|
||||
payload = parse_token(token)
|
||||
if not payload:
|
||||
raise HTTPException(status_code=401, detail="invalid token")
|
||||
raise HTTPException(status_code=401, detail="invalid or expired token")
|
||||
if payload.get("type") != "access":
|
||||
raise HTTPException(status_code=401, detail="invalid token type")
|
||||
conn = get_conn()
|
||||
user = conn.execute("SELECT * FROM users WHERE id = ?", (payload["sub"],)).fetchone()
|
||||
conn.close()
|
||||
if not user:
|
||||
raise HTTPException(status_code=401, detail="user not found")
|
||||
if user["banned"]:
|
||||
raise HTTPException(status_code=403, detail="account banned")
|
||||
return dict(user)
|
||||
|
||||
|
||||
def require_admin(user: dict = Depends(get_current_user)) -> dict:
|
||||
"""FastAPI dependency: requires admin role"""
|
||||
if user.get("role") != "admin":
|
||||
raise HTTPException(status_code=403, detail="admin required")
|
||||
return user
|
||||
|
||||
|
||||
# ─── Request models ──────────────────────────────────────────
|
||||
|
||||
class RegisterReq(BaseModel):
|
||||
email: EmailStr
|
||||
password: str
|
||||
invite_code: str
|
||||
|
||||
|
||||
class LoginReq(BaseModel):
|
||||
@ -133,18 +212,47 @@ class LoginReq(BaseModel):
|
||||
password: str
|
||||
|
||||
|
||||
class BindDiscordReq(BaseModel):
|
||||
discord_id: str
|
||||
class RefreshReq(BaseModel):
|
||||
refresh_token: str
|
||||
|
||||
|
||||
class GenInviteReq(BaseModel):
|
||||
count: int = 1
|
||||
max_uses: int = 1
|
||||
|
||||
|
||||
class BanUserReq(BaseModel):
|
||||
banned: bool = True
|
||||
|
||||
|
||||
# ─── Auth routes ─────────────────────────────────────────────
|
||||
|
||||
@router.post("/auth/register")
|
||||
def register(body: RegisterReq):
|
||||
ensure_tables()
|
||||
conn = get_conn()
|
||||
|
||||
# verify invite code
|
||||
invite = conn.execute(
|
||||
"SELECT * FROM invite_codes WHERE code = ?", (body.invite_code,)
|
||||
).fetchone()
|
||||
if not invite:
|
||||
conn.close()
|
||||
raise HTTPException(status_code=400, detail="invalid invite code")
|
||||
if invite["status"] != "active":
|
||||
conn.close()
|
||||
raise HTTPException(status_code=400, detail="invite code disabled")
|
||||
if invite["used_count"] >= invite["max_uses"]:
|
||||
conn.close()
|
||||
raise HTTPException(status_code=400, detail="invite code exhausted")
|
||||
if invite["expires_at"] and invite["expires_at"] < datetime.utcnow().isoformat():
|
||||
conn.close()
|
||||
raise HTTPException(status_code=400, detail="invite code expired")
|
||||
|
||||
try:
|
||||
pwd_hash = hash_password(body.password)
|
||||
cur = conn.execute(
|
||||
"INSERT INTO users (email, password_hash, created_at) VALUES (?, ?, ?)",
|
||||
"INSERT INTO users (email, password_hash, role, banned, created_at) VALUES (?, ?, 'user', 0, ?)",
|
||||
(body.email.lower(), pwd_hash, datetime.utcnow().isoformat()),
|
||||
)
|
||||
user_id = cur.lastrowid
|
||||
@ -152,14 +260,39 @@ def register(body: RegisterReq):
|
||||
"INSERT OR REPLACE INTO subscriptions (user_id, tier, expires_at) VALUES (?, 'free', NULL)",
|
||||
(user_id,),
|
||||
)
|
||||
# record invite usage
|
||||
conn.execute(
|
||||
"INSERT INTO invite_usage (code_id, user_id) VALUES (?, ?)",
|
||||
(invite["id"], user_id),
|
||||
)
|
||||
conn.execute(
|
||||
"UPDATE invite_codes SET used_count = used_count + 1 WHERE id = ?",
|
||||
(invite["id"],),
|
||||
)
|
||||
# auto-exhaust if max reached
|
||||
new_count = invite["used_count"] + 1
|
||||
if new_count >= invite["max_uses"]:
|
||||
conn.execute(
|
||||
"UPDATE invite_codes SET status = 'exhausted' WHERE id = ?",
|
||||
(invite["id"],),
|
||||
)
|
||||
conn.commit()
|
||||
except sqlite3.IntegrityError:
|
||||
conn.close()
|
||||
raise HTTPException(status_code=400, detail="email already registered")
|
||||
|
||||
user = conn.execute("SELECT id, email, discord_id, created_at FROM users WHERE id = ?", (user_id,)).fetchone()
|
||||
# issue tokens
|
||||
user = conn.execute("SELECT * FROM users WHERE id = ?", (user_id,)).fetchone()
|
||||
conn.close()
|
||||
return dict(user)
|
||||
access = create_access_token(user["id"], user["email"], user["role"])
|
||||
refresh = create_refresh_token(user["id"])
|
||||
return {
|
||||
"access_token": access,
|
||||
"refresh_token": refresh,
|
||||
"token_type": "bearer",
|
||||
"expires_in": ACCESS_TOKEN_HOURS * 3600,
|
||||
"user": {"id": user["id"], "email": user["email"], "role": user["role"]},
|
||||
}
|
||||
|
||||
|
||||
@router.post("/auth/login")
|
||||
@ -170,27 +303,113 @@ def login(body: LoginReq):
|
||||
conn.close()
|
||||
if not user or not verify_password(body.password, user["password_hash"]):
|
||||
raise HTTPException(status_code=401, detail="invalid credentials")
|
||||
token = create_token(user["id"], user["email"])
|
||||
return {"token": token, "token_type": "bearer"}
|
||||
if user["banned"]:
|
||||
raise HTTPException(status_code=403, detail="account banned")
|
||||
access = create_access_token(user["id"], user["email"], user["role"])
|
||||
refresh = create_refresh_token(user["id"])
|
||||
return {
|
||||
"access_token": access,
|
||||
"refresh_token": refresh,
|
||||
"token_type": "bearer",
|
||||
"expires_in": ACCESS_TOKEN_HOURS * 3600,
|
||||
"user": {"id": user["id"], "email": user["email"], "role": user["role"]},
|
||||
}
|
||||
|
||||
|
||||
@router.post("/user/bind-discord")
|
||||
def bind_discord(body: BindDiscordReq, authorization: Optional[str] = Header(default=None)):
|
||||
user = get_user_from_auth(authorization)
|
||||
@router.post("/auth/refresh")
|
||||
def refresh_token(body: RefreshReq):
|
||||
conn = get_conn()
|
||||
conn.execute("UPDATE users SET discord_id = ? WHERE id = ?", (body.discord_id, user["id"]))
|
||||
conn.commit()
|
||||
updated = conn.execute("SELECT id, email, discord_id, created_at FROM users WHERE id = ?", (user["id"],)).fetchone()
|
||||
row = conn.execute(
|
||||
"SELECT * FROM refresh_tokens WHERE token = ? AND revoked = 0", (body.refresh_token,)
|
||||
).fetchone()
|
||||
if not row:
|
||||
conn.close()
|
||||
return dict(updated)
|
||||
raise HTTPException(status_code=401, detail="invalid refresh token")
|
||||
if row["expires_at"] < datetime.utcnow().isoformat():
|
||||
conn.close()
|
||||
raise HTTPException(status_code=401, detail="refresh token expired")
|
||||
user = conn.execute("SELECT * FROM users WHERE id = ?", (row["user_id"],)).fetchone()
|
||||
if not user or user["banned"]:
|
||||
conn.close()
|
||||
raise HTTPException(status_code=403, detail="account unavailable")
|
||||
# revoke old, issue new
|
||||
conn.execute("UPDATE refresh_tokens SET revoked = 1 WHERE id = ?", (row["id"],))
|
||||
conn.commit()
|
||||
conn.close()
|
||||
access = create_access_token(user["id"], user["email"], user["role"])
|
||||
new_refresh = create_refresh_token(user["id"])
|
||||
return {
|
||||
"access_token": access,
|
||||
"refresh_token": new_refresh,
|
||||
"token_type": "bearer",
|
||||
"expires_in": ACCESS_TOKEN_HOURS * 3600,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/user/me")
|
||||
def me(authorization: Optional[str] = Header(default=None)):
|
||||
user = get_user_from_auth(authorization)
|
||||
@router.get("/auth/me")
|
||||
def me(user: dict = Depends(get_current_user)):
|
||||
conn = get_conn()
|
||||
sub = conn.execute("SELECT tier, expires_at FROM subscriptions WHERE user_id = ?", (user["id"],)).fetchone()
|
||||
conn.close()
|
||||
out = {"id": user["id"], "email": user["email"], "discord_id": user["discord_id"], "created_at": user["created_at"]}
|
||||
out["subscription"] = dict(sub) if sub else {"tier": "free", "expires_at": None}
|
||||
return out
|
||||
return {
|
||||
"id": user["id"], "email": user["email"], "role": user["role"],
|
||||
"discord_id": user.get("discord_id"),
|
||||
"created_at": user["created_at"],
|
||||
"subscription": dict(sub) if sub else {"tier": "free", "expires_at": None},
|
||||
}
|
||||
|
||||
|
||||
# ─── Admin routes ────────────────────────────────────────────
|
||||
|
||||
@router.post("/admin/invite-codes")
|
||||
def gen_invite_codes(body: GenInviteReq, admin: dict = Depends(require_admin)):
|
||||
conn = get_conn()
|
||||
codes = []
|
||||
for _ in range(body.count):
|
||||
code = secrets.token_urlsafe(6)[:8].upper()
|
||||
conn.execute(
|
||||
"INSERT INTO invite_codes (code, created_by, max_uses) VALUES (?, ?, ?)",
|
||||
(code, admin["id"], body.max_uses),
|
||||
)
|
||||
codes.append(code)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
return {"codes": codes}
|
||||
|
||||
|
||||
@router.get("/admin/invite-codes")
|
||||
def list_invite_codes(admin: dict = Depends(require_admin)):
|
||||
conn = get_conn()
|
||||
rows = conn.execute(
|
||||
"SELECT id, code, max_uses, used_count, status, expires_at, created_at FROM invite_codes ORDER BY id DESC"
|
||||
).fetchall()
|
||||
conn.close()
|
||||
return {"items": [dict(r) for r in rows]}
|
||||
|
||||
|
||||
@router.delete("/admin/invite-codes/{code_id}")
|
||||
def disable_invite_code(code_id: int, admin: dict = Depends(require_admin)):
|
||||
conn = get_conn()
|
||||
conn.execute("UPDATE invite_codes SET status = 'disabled' WHERE id = ?", (code_id,))
|
||||
conn.commit()
|
||||
conn.close()
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@router.get("/admin/users")
|
||||
def list_users(admin: dict = Depends(require_admin)):
|
||||
conn = get_conn()
|
||||
rows = conn.execute(
|
||||
"SELECT id, email, role, banned, discord_id, created_at FROM users ORDER BY id DESC"
|
||||
).fetchall()
|
||||
conn.close()
|
||||
return {"items": [dict(r) for r in rows]}
|
||||
|
||||
|
||||
@router.put("/admin/users/{user_id}/ban")
|
||||
def ban_user(user_id: int, body: BanUserReq, admin: dict = Depends(require_admin)):
|
||||
conn = get_conn()
|
||||
conn.execute("UPDATE users SET banned = ? WHERE id = ?", (1 if body.banned else 0, user_id))
|
||||
conn.commit()
|
||||
conn.close()
|
||||
return {"ok": True, "user_id": user_id, "banned": body.banned}
|
||||
|
||||
@ -1,9 +1,11 @@
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi import FastAPI, HTTPException, Depends
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
import httpx
|
||||
from datetime import datetime, timedelta
|
||||
import asyncio, time, sqlite3, os
|
||||
|
||||
from auth import router as auth_router, get_current_user, ensure_tables as ensure_auth_tables
|
||||
|
||||
app = FastAPI(title="Arbitrage Engine API")
|
||||
|
||||
app.add_middleware(
|
||||
@ -13,6 +15,8 @@ app.add_middleware(
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
app.include_router(auth_router)
|
||||
|
||||
BINANCE_FAPI = "https://fapi.binance.com/fapi/v1"
|
||||
SYMBOLS = ["BTCUSDT", "ETHUSDT"]
|
||||
HEADERS = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"}
|
||||
@ -101,6 +105,7 @@ async def background_snapshot_loop():
|
||||
@app.on_event("startup")
|
||||
async def startup():
|
||||
init_db()
|
||||
ensure_auth_tables()
|
||||
asyncio.create_task(background_snapshot_loop())
|
||||
|
||||
|
||||
@ -137,7 +142,7 @@ async def get_rates():
|
||||
|
||||
|
||||
@app.get("/api/snapshots")
|
||||
async def get_snapshots(hours: int = 24, limit: int = 5000):
|
||||
async def get_snapshots(hours: int = 24, limit: int = 5000, user: dict = Depends(get_current_user)):
|
||||
"""查询本地落库的实时快照数据"""
|
||||
since = int(time.time()) - hours * 3600
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
@ -155,7 +160,7 @@ async def get_snapshots(hours: int = 24, limit: int = 5000):
|
||||
|
||||
|
||||
@app.get("/api/kline")
|
||||
async def get_kline(symbol: str = "BTC", interval: str = "5m", limit: int = 500):
|
||||
async def get_kline(symbol: str = "BTC", interval: str = "5m", limit: int = 500, user: dict = Depends(get_current_user)):
|
||||
"""
|
||||
从 rate_snapshots 聚合K线数据
|
||||
symbol: BTC | ETH
|
||||
@ -212,7 +217,7 @@ async def get_kline(symbol: str = "BTC", interval: str = "5m", limit: int = 500)
|
||||
|
||||
|
||||
@app.get("/api/stats/ytd")
|
||||
async def get_stats_ytd():
|
||||
async def get_stats_ytd(user: dict = Depends(get_current_user)):
|
||||
"""今年以来(YTD)资金费率年化统计"""
|
||||
cached = get_cache("stats_ytd", 3600)
|
||||
if cached: return cached
|
||||
@ -245,7 +250,7 @@ async def get_stats_ytd():
|
||||
|
||||
|
||||
@app.get("/api/signals/history")
|
||||
async def get_signals_history(limit: int = 100):
|
||||
async def get_signals_history(limit: int = 100, user: dict = Depends(get_current_user)):
|
||||
"""查询信号推送历史"""
|
||||
try:
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
@ -261,7 +266,7 @@ async def get_signals_history(limit: int = 100):
|
||||
|
||||
|
||||
@app.get("/api/history")
|
||||
async def get_history():
|
||||
async def get_history(user: dict = Depends(get_current_user)):
|
||||
cached = get_cache("history", 60)
|
||||
if cached: return cached
|
||||
end_time = int(datetime.utcnow().timestamp() * 1000)
|
||||
@ -288,7 +293,7 @@ async def get_history():
|
||||
|
||||
|
||||
@app.get("/api/stats")
|
||||
async def get_stats():
|
||||
async def get_stats(user: dict = Depends(get_current_user)):
|
||||
cached = get_cache("stats", 60)
|
||||
if cached: return cached
|
||||
end_time = int(datetime.utcnow().timestamp() * 1000)
|
||||
|
||||
@ -2,7 +2,8 @@ import type { Metadata } from "next";
|
||||
import { Geist, Geist_Mono } from "next/font/google";
|
||||
import "./globals.css";
|
||||
import Sidebar from "@/components/Sidebar";
|
||||
import Link from "next/link";
|
||||
import { AuthProvider } from "@/lib/auth";
|
||||
import AuthHeader from "@/components/AuthHeader";
|
||||
|
||||
const geistSans = Geist({ variable: "--font-geist-sans", subsets: ["latin"] });
|
||||
const geistMono = Geist_Mono({ variable: "--font-geist-mono", subsets: ["latin"] });
|
||||
@ -16,25 +17,17 @@ export default function RootLayout({ children }: Readonly<{ children: React.Reac
|
||||
return (
|
||||
<html lang="zh">
|
||||
<body className={`${geistSans.variable} ${geistMono.variable} antialiased min-h-screen bg-slate-50 text-slate-900`}>
|
||||
<AuthProvider>
|
||||
<div className="flex min-h-screen">
|
||||
<Sidebar />
|
||||
<div className="flex-1 flex flex-col min-w-0">
|
||||
{/* 桌面端顶栏:右上角登录注册 */}
|
||||
<header className="hidden md:flex items-center justify-end px-6 py-3 bg-white border-b border-slate-200 gap-3">
|
||||
<Link href="/login"
|
||||
className="text-sm text-slate-600 hover:text-blue-600 px-3 py-1.5 rounded-lg border border-slate-200 hover:border-blue-300 transition-colors">
|
||||
登录
|
||||
</Link>
|
||||
<Link href="/register"
|
||||
className="text-sm text-white bg-blue-600 hover:bg-blue-700 px-3 py-1.5 rounded-lg transition-colors font-medium">
|
||||
注册
|
||||
</Link>
|
||||
</header>
|
||||
<AuthHeader />
|
||||
<main className="flex-1 p-4 md:p-6 pt-16 md:pt-6">
|
||||
{children}
|
||||
</main>
|
||||
</div>
|
||||
</div>
|
||||
</AuthProvider>
|
||||
</body>
|
||||
</html>
|
||||
);
|
||||
|
||||
@ -1,10 +1,13 @@
|
||||
"use client";
|
||||
import { useState, Suspense } from "react";
|
||||
import { useRouter, useSearchParams } from "next/navigation";
|
||||
|
||||
function LoginForm() {
|
||||
import { useState } from "react";
|
||||
import { useAuth } from "@/lib/auth";
|
||||
import { useRouter } from "next/navigation";
|
||||
import Link from "next/link";
|
||||
|
||||
export default function LoginPage() {
|
||||
const { login } = useAuth();
|
||||
const router = useRouter();
|
||||
const params = useSearchParams();
|
||||
const [email, setEmail] = useState("");
|
||||
const [password, setPassword] = useState("");
|
||||
const [error, setError] = useState("");
|
||||
@ -12,74 +15,61 @@ function LoginForm() {
|
||||
|
||||
const handleSubmit = async (e: React.FormEvent) => {
|
||||
e.preventDefault();
|
||||
setLoading(true);
|
||||
setError("");
|
||||
setLoading(true);
|
||||
try {
|
||||
const form = new URLSearchParams();
|
||||
form.append("username", email);
|
||||
form.append("password", password);
|
||||
const r = await fetch("/api/auth/login", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/x-www-form-urlencoded" },
|
||||
body: form.toString(),
|
||||
});
|
||||
const data = await r.json();
|
||||
if (!r.ok) { setError(data.detail || "登录失败"); return; }
|
||||
localStorage.setItem("arb_token", data.access_token);
|
||||
router.push("/dashboard");
|
||||
} catch {
|
||||
setError("网络错误,请重试");
|
||||
await login(email, password);
|
||||
router.push("/");
|
||||
} catch (err: any) {
|
||||
setError(err.message || "登录失败");
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="min-h-[70vh] flex items-center justify-center">
|
||||
<div className="w-full max-w-md rounded-xl border border-slate-200 bg-white p-8 space-y-6">
|
||||
<div>
|
||||
<h1 className="text-2xl font-bold text-slate-900">登录</h1>
|
||||
{params.get("registered") && (
|
||||
<p className="text-emerald-400 text-sm mt-1">✅ 注册成功,请登录</p>
|
||||
)}
|
||||
<p className="text-slate-500 text-sm mt-1">登录后查看信号和账户信息</p>
|
||||
<div className="min-h-screen bg-slate-50 flex items-center justify-center p-4">
|
||||
<div className="w-full max-w-md">
|
||||
<div className="bg-white rounded-xl shadow-sm border border-slate-200 p-8">
|
||||
<div className="text-center mb-8">
|
||||
<h1 className="text-2xl font-bold text-slate-800">⚡ Arbitrage Engine</h1>
|
||||
<p className="text-slate-500 text-sm mt-2">登录您的账户</p>
|
||||
</div>
|
||||
|
||||
<form onSubmit={handleSubmit} className="space-y-4">
|
||||
<div>
|
||||
<label className="block text-sm text-slate-700 mb-1">邮箱</label>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">邮箱</label>
|
||||
<input
|
||||
type="email" required value={email} onChange={e => setEmail(e.target.value)}
|
||||
className="w-full bg-white border border-slate-200 rounded-lg px-3 py-2 text-slate-900 text-sm focus:outline-none focus:border-cyan-500"
|
||||
placeholder="your@email.com"
|
||||
type="email" value={email} onChange={e => setEmail(e.target.value)}
|
||||
className="w-full px-3 py-2 border border-slate-300 rounded-lg text-sm focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-transparent"
|
||||
placeholder="you@example.com" required
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label className="block text-sm text-slate-700 mb-1">密码</label>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">密码</label>
|
||||
<input
|
||||
type="password" required value={password} onChange={e => setPassword(e.target.value)}
|
||||
className="w-full bg-white border border-slate-200 rounded-lg px-3 py-2 text-slate-900 text-sm focus:outline-none focus:border-cyan-500"
|
||||
type="password" value={password} onChange={e => setPassword(e.target.value)}
|
||||
className="w-full px-3 py-2 border border-slate-300 rounded-lg text-sm focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-transparent"
|
||||
placeholder="••••••••" required
|
||||
/>
|
||||
</div>
|
||||
{error && <p className="text-red-400 text-sm">{error}</p>}
|
||||
|
||||
{error && <p className="text-red-500 text-sm">{error}</p>}
|
||||
|
||||
<button
|
||||
type="submit" disabled={loading}
|
||||
className="w-full bg-cyan-600 hover:bg-cyan-500 disabled:opacity-50 text-white font-medium py-2 rounded-lg text-sm transition-colors"
|
||||
className="w-full bg-blue-600 hover:bg-blue-700 text-white py-2.5 rounded-lg font-medium text-sm transition-colors disabled:opacity-50"
|
||||
>
|
||||
{loading ? "登录中..." : "登录"}
|
||||
</button>
|
||||
</form>
|
||||
<p className="text-center text-sm text-slate-500">
|
||||
没有账号?<a href="/register" className="text-blue-600 hover:underline">注册</a>
|
||||
|
||||
<p className="text-center text-sm text-slate-500 mt-6">
|
||||
没有账户?{" "}
|
||||
<Link href="/register" className="text-blue-600 hover:underline">注册</Link>
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export default function LoginPage() {
|
||||
return (
|
||||
<Suspense>
|
||||
<LoginForm />
|
||||
</Suspense>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@ -3,8 +3,10 @@
|
||||
import { useEffect, useState, useCallback, useRef } from "react";
|
||||
import { createChart, ColorType, CandlestickSeries } from "lightweight-charts";
|
||||
import { api, RatesResponse, StatsResponse, HistoryResponse, HistoryPoint, SignalHistoryItem, KBar, YtdStatsResponse } from "@/lib/api";
|
||||
import { useAuth } from "@/lib/auth";
|
||||
import RateCard from "@/components/RateCard";
|
||||
import StatsCard from "@/components/StatsCard";
|
||||
import Link from "next/link";
|
||||
import {
|
||||
LineChart, Line, XAxis, YAxis, Tooltip, Legend,
|
||||
ResponsiveContainer, ReferenceLine
|
||||
@ -91,8 +93,35 @@ function IntervalPicker({ value, onChange }: { value: string; onChange: (v: stri
|
||||
);
|
||||
}
|
||||
|
||||
// ─── 未登录遮挡组件 ──────────────────────────────────────────────
|
||||
function AuthGate({ children }: { children: React.ReactNode }) {
|
||||
const { isLoggedIn, loading } = useAuth();
|
||||
if (loading) return <div className="text-center text-slate-400 py-12">加载中...</div>;
|
||||
if (!isLoggedIn) {
|
||||
return (
|
||||
<div className="relative">
|
||||
<div className="filter blur-sm pointer-events-none select-none opacity-60">
|
||||
{children}
|
||||
</div>
|
||||
<div className="absolute inset-0 flex items-center justify-center bg-white/70 rounded-xl">
|
||||
<div className="text-center space-y-3">
|
||||
<div className="text-4xl">🔒</div>
|
||||
<p className="text-slate-600 font-medium">登录后查看完整数据</p>
|
||||
<div className="flex gap-2 justify-center">
|
||||
<Link href="/login" className="text-sm bg-blue-600 hover:bg-blue-700 text-white px-4 py-2 rounded-lg transition-colors">登录</Link>
|
||||
<Link href="/register" className="text-sm border border-slate-300 text-slate-600 hover:border-blue-400 px-4 py-2 rounded-lg transition-colors">注册</Link>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
return <>{children}</>;
|
||||
}
|
||||
|
||||
// ─── 主仪表盘 ────────────────────────────────────────────────────
|
||||
export default function Dashboard() {
|
||||
const { isLoggedIn } = useAuth();
|
||||
const [rates, setRates] = useState<RatesResponse | null>(null);
|
||||
const [stats, setStats] = useState<StatsResponse | null>(null);
|
||||
const [history, setHistory] = useState<HistoryResponse | null>(null);
|
||||
@ -110,11 +139,12 @@ export default function Dashboard() {
|
||||
}, []);
|
||||
|
||||
const fetchSlow = useCallback(async () => {
|
||||
if (!isLoggedIn) return;
|
||||
try {
|
||||
const [s, h, sig, y] = await Promise.all([api.stats(), api.history(), api.signalsHistory(), api.statsYtd()]);
|
||||
setStats(s); setHistory(h); setSignals(sig.items || []); setYtd(y);
|
||||
} catch {}
|
||||
}, []);
|
||||
}, [isLoggedIn]);
|
||||
|
||||
useEffect(() => {
|
||||
fetchRates(); fetchSlow();
|
||||
@ -154,6 +184,7 @@ export default function Dashboard() {
|
||||
</div>
|
||||
|
||||
{/* 统计卡片 */}
|
||||
<AuthGate>
|
||||
{stats && (
|
||||
<div className="grid grid-cols-3 gap-3">
|
||||
<StatsCard title="BTC 套利" mean7d={stats.BTC.mean7d} annualized={stats.BTC.annualized} accent="blue" />
|
||||
@ -283,6 +314,7 @@ export default function Dashboard() {
|
||||
<span className="text-blue-600 font-medium">策略原理:</span>
|
||||
持有现货多头 + 永续空头,每8小时收取资金费率,赚取无方向风险的稳定收益。
|
||||
</div>
|
||||
</AuthGate>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@ -1,80 +1,88 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { useAuth } from "@/lib/auth";
|
||||
import { useRouter } from "next/navigation";
|
||||
import Link from "next/link";
|
||||
|
||||
export default function RegisterPage() {
|
||||
const { register } = useAuth();
|
||||
const router = useRouter();
|
||||
const [email, setEmail] = useState("");
|
||||
const [password, setPassword] = useState("");
|
||||
const [discordId, setDiscordId] = useState("");
|
||||
const [inviteCode, setInviteCode] = useState("");
|
||||
const [error, setError] = useState("");
|
||||
const [loading, setLoading] = useState(false);
|
||||
|
||||
const handleSubmit = async (e: React.FormEvent) => {
|
||||
e.preventDefault();
|
||||
setLoading(true);
|
||||
setError("");
|
||||
if (password.length < 6) {
|
||||
setError("密码至少6位");
|
||||
return;
|
||||
}
|
||||
setLoading(true);
|
||||
try {
|
||||
const r = await fetch("/api/auth/register", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({ email, password, discord_id: discordId || undefined }),
|
||||
});
|
||||
const data = await r.json();
|
||||
if (!r.ok) { setError(data.detail || "注册失败"); return; }
|
||||
router.push("/login?registered=1");
|
||||
} catch {
|
||||
setError("网络错误,请重试");
|
||||
await register(email, password, inviteCode);
|
||||
router.push("/");
|
||||
} catch (err: any) {
|
||||
setError(err.message || "注册失败");
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="min-h-[70vh] flex items-center justify-center">
|
||||
<div className="w-full max-w-md rounded-xl border border-slate-200 bg-white p-8 space-y-6">
|
||||
<div>
|
||||
<h1 className="text-2xl font-bold text-slate-900">注册账号</h1>
|
||||
<p className="text-slate-500 text-sm mt-1">注册后可接收套利信号推送</p>
|
||||
<div className="min-h-screen bg-slate-50 flex items-center justify-center p-4">
|
||||
<div className="w-full max-w-md">
|
||||
<div className="bg-white rounded-xl shadow-sm border border-slate-200 p-8">
|
||||
<div className="text-center mb-8">
|
||||
<h1 className="text-2xl font-bold text-slate-800">⚡ Arbitrage Engine</h1>
|
||||
<p className="text-slate-500 text-sm mt-2">注册新账户(需要邀请码)</p>
|
||||
</div>
|
||||
|
||||
<form onSubmit={handleSubmit} className="space-y-4">
|
||||
<div>
|
||||
<label className="block text-sm text-slate-700 mb-1">邮箱</label>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">邀请码</label>
|
||||
<input
|
||||
type="email" required value={email} onChange={e => setEmail(e.target.value)}
|
||||
className="w-full bg-white border border-slate-200 rounded-lg px-3 py-2 text-slate-900 text-sm focus:outline-none focus:border-cyan-500"
|
||||
placeholder="your@email.com"
|
||||
type="text" value={inviteCode} onChange={e => setInviteCode(e.target.value.toUpperCase())}
|
||||
className="w-full px-3 py-2 border border-slate-300 rounded-lg text-sm focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-transparent font-mono tracking-wider"
|
||||
placeholder="输入邀请码" required
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label className="block text-sm text-slate-700 mb-1">密码</label>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">邮箱</label>
|
||||
<input
|
||||
type="password" required value={password} onChange={e => setPassword(e.target.value)}
|
||||
className="w-full bg-white border border-slate-200 rounded-lg px-3 py-2 text-slate-900 text-sm focus:outline-none focus:border-cyan-500"
|
||||
placeholder="至少8位"
|
||||
minLength={8}
|
||||
type="email" value={email} onChange={e => setEmail(e.target.value)}
|
||||
className="w-full px-3 py-2 border border-slate-300 rounded-lg text-sm focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-transparent"
|
||||
placeholder="you@example.com" required
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label className="block text-sm text-slate-700 mb-1">Discord ID <span className="text-slate-500">(选填,用于接收信号)</span></label>
|
||||
<label className="block text-sm font-medium text-slate-700 mb-1">密码</label>
|
||||
<input
|
||||
type="text" value={discordId} onChange={e => setDiscordId(e.target.value)}
|
||||
className="w-full bg-white border border-slate-200 rounded-lg px-3 py-2 text-slate-900 text-sm focus:outline-none focus:border-cyan-500"
|
||||
placeholder="例:123456789012345678"
|
||||
type="password" value={password} onChange={e => setPassword(e.target.value)}
|
||||
className="w-full px-3 py-2 border border-slate-300 rounded-lg text-sm focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-transparent"
|
||||
placeholder="至少6位" required minLength={6}
|
||||
/>
|
||||
</div>
|
||||
{error && <p className="text-red-400 text-sm">{error}</p>}
|
||||
|
||||
{error && <p className="text-red-500 text-sm">{error}</p>}
|
||||
|
||||
<button
|
||||
type="submit" disabled={loading}
|
||||
className="w-full bg-cyan-600 hover:bg-cyan-500 disabled:opacity-50 text-white font-medium py-2 rounded-lg text-sm transition-colors"
|
||||
className="w-full bg-blue-600 hover:bg-blue-700 text-white py-2.5 rounded-lg font-medium text-sm transition-colors disabled:opacity-50"
|
||||
>
|
||||
{loading ? "注册中..." : "注册"}
|
||||
</button>
|
||||
</form>
|
||||
<p className="text-center text-sm text-slate-500">
|
||||
已有账号?<a href="/login" className="text-blue-600 hover:underline">登录</a>
|
||||
|
||||
<p className="text-center text-sm text-slate-500 mt-6">
|
||||
已有账户?{" "}
|
||||
<Link href="/login" className="text-blue-600 hover:underline">登录</Link>
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
54
frontend/components/AuthHeader.tsx
Normal file
54
frontend/components/AuthHeader.tsx
Normal file
@ -0,0 +1,54 @@
|
||||
"use client";
|
||||
|
||||
import Link from "next/link";
|
||||
import { useAuth } from "@/lib/auth";
|
||||
|
||||
export default function AuthHeader() {
|
||||
const { user, isLoggedIn, logout } = useAuth();
|
||||
|
||||
return (
|
||||
<>
|
||||
{/* 桌面端顶栏 */}
|
||||
<header className="hidden md:flex items-center justify-end px-6 py-3 bg-white border-b border-slate-200 gap-3">
|
||||
{isLoggedIn ? (
|
||||
<>
|
||||
<span className="text-sm text-slate-500">{user?.email}</span>
|
||||
{user?.role === "admin" && (
|
||||
<span className="bg-amber-100 text-amber-700 text-xs px-2 py-0.5 rounded-full font-medium">Admin</span>
|
||||
)}
|
||||
<button onClick={logout}
|
||||
className="text-sm text-slate-600 hover:text-red-500 px-3 py-1.5 rounded-lg border border-slate-200 hover:border-red-300 transition-colors">
|
||||
退出
|
||||
</button>
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<Link href="/login"
|
||||
className="text-sm text-slate-600 hover:text-blue-600 px-3 py-1.5 rounded-lg border border-slate-200 hover:border-blue-300 transition-colors">
|
||||
登录
|
||||
</Link>
|
||||
<Link href="/register"
|
||||
className="text-sm text-white bg-blue-600 hover:bg-blue-700 px-3 py-1.5 rounded-lg transition-colors font-medium">
|
||||
注册
|
||||
</Link>
|
||||
</>
|
||||
)}
|
||||
</header>
|
||||
|
||||
{/* 手机端顶栏 */}
|
||||
<header className="md:hidden flex items-center justify-between px-4 py-2 bg-white border-b border-slate-200 fixed top-0 left-0 right-0 z-40">
|
||||
<span className="font-bold text-blue-600 text-sm">⚡ ArbEngine</span>
|
||||
<div className="flex items-center gap-2">
|
||||
{isLoggedIn ? (
|
||||
<button onClick={logout} className="text-xs text-slate-500 hover:text-red-500">退出</button>
|
||||
) : (
|
||||
<>
|
||||
<Link href="/login" className="text-xs text-slate-500">登录</Link>
|
||||
<Link href="/register" className="text-xs bg-blue-600 text-white px-2 py-1 rounded">注册</Link>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</header>
|
||||
</>
|
||||
);
|
||||
}
|
||||
@ -1,17 +1,11 @@
|
||||
"use client";
|
||||
import Link from "next/link";
|
||||
import { useState } from "react";
|
||||
|
||||
const navLinks = [
|
||||
{ href: "/", label: "仪表盘" },
|
||||
{ href: "/kline", label: "K线" },
|
||||
{ href: "/live", label: "实时" },
|
||||
{ href: "/signals", label: "信号" },
|
||||
{ href: "/about", label: "说明" },
|
||||
];
|
||||
import { useAuth } from "@/lib/auth";
|
||||
|
||||
export default function Navbar() {
|
||||
const [open, setOpen] = useState(false);
|
||||
const { user, isLoggedIn, logout } = useAuth();
|
||||
|
||||
return (
|
||||
<nav className="border-b border-slate-200 bg-white sticky top-0 z-50 shadow-sm">
|
||||
@ -22,21 +16,36 @@ export default function Navbar() {
|
||||
|
||||
{/* Desktop nav */}
|
||||
<div className="hidden md:flex items-center gap-6 text-sm ml-8">
|
||||
{navLinks.map(l => (
|
||||
<Link key={l.href} href={l.href} className="text-slate-600 hover:text-blue-600 transition-colors">
|
||||
{l.label}
|
||||
</Link>
|
||||
))}
|
||||
<Link href="/" className="text-slate-600 hover:text-blue-600 transition-colors">仪表盘</Link>
|
||||
<Link href="/about" className="text-slate-600 hover:text-blue-600 transition-colors">说明</Link>
|
||||
</div>
|
||||
<div className="hidden md:flex items-center gap-3 text-sm ml-auto">
|
||||
{isLoggedIn ? (
|
||||
<>
|
||||
<span className="text-slate-500 text-xs">{user?.email}</span>
|
||||
{user?.role === "admin" && (
|
||||
<span className="bg-amber-100 text-amber-700 text-xs px-2 py-0.5 rounded-full font-medium">Admin</span>
|
||||
)}
|
||||
<button onClick={logout} className="text-slate-500 hover:text-red-500 transition-colors">退出</button>
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<Link href="/login" className="text-slate-500 hover:text-slate-800 transition-colors">登录</Link>
|
||||
<Link href="/register" className="bg-blue-600 hover:bg-blue-700 text-white px-3 py-1.5 rounded-lg transition-colors">注册</Link>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Mobile */}
|
||||
<div className="md:hidden ml-auto flex items-center gap-2">
|
||||
{isLoggedIn ? (
|
||||
<button onClick={logout} className="text-slate-500 text-sm">退出</button>
|
||||
) : (
|
||||
<>
|
||||
<Link href="/login" className="text-slate-500 text-sm">登录</Link>
|
||||
<Link href="/register" className="bg-blue-600 text-white px-2 py-1 rounded text-sm">注册</Link>
|
||||
</>
|
||||
)}
|
||||
<button onClick={() => setOpen(!open)} className="ml-1 p-2 text-slate-500" aria-label="菜单">
|
||||
{open ? (
|
||||
<svg className="w-5 h-5" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
@ -53,12 +62,8 @@ export default function Navbar() {
|
||||
|
||||
{open && (
|
||||
<div className="md:hidden border-t border-slate-100 bg-white px-4 py-3 space-y-1">
|
||||
{navLinks.map(l => (
|
||||
<Link key={l.href} href={l.href} onClick={() => setOpen(false)}
|
||||
className="block py-2 text-slate-600 hover:text-blue-600 transition-colors text-sm">
|
||||
{l.label}
|
||||
</Link>
|
||||
))}
|
||||
<Link href="/" onClick={() => setOpen(false)} className="block py-2 text-slate-600 hover:text-blue-600 text-sm">仪表盘</Link>
|
||||
<Link href="/about" onClick={() => setOpen(false)} className="block py-2 text-slate-600 hover:text-blue-600 text-sm">说明</Link>
|
||||
</div>
|
||||
)}
|
||||
</nav>
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
import { authFetch } from "./auth";
|
||||
|
||||
const API_BASE = process.env.NEXT_PUBLIC_API_URL ?? "";
|
||||
|
||||
export interface RateData {
|
||||
@ -84,21 +86,31 @@ export interface YtdStatsResponse {
|
||||
ETH: YtdStats;
|
||||
}
|
||||
|
||||
async function fetchAPI<T>(path: string): Promise<T> {
|
||||
// Public fetch (no auth needed)
|
||||
async function fetchPublic<T>(path: string): Promise<T> {
|
||||
const res = await fetch(`${API_BASE}${path}`, { cache: "no-store" });
|
||||
if (!res.ok) throw new Error(`API error ${res.status}`);
|
||||
return res.json();
|
||||
}
|
||||
|
||||
// Protected fetch (auth required, auto-refresh)
|
||||
async function fetchProtected<T>(path: string): Promise<T> {
|
||||
const res = await authFetch(path, { cache: "no-store" });
|
||||
if (!res.ok) throw new Error(`API error ${res.status}`);
|
||||
return res.json();
|
||||
}
|
||||
|
||||
export const api = {
|
||||
rates: () => fetchAPI<RatesResponse>("/api/rates"),
|
||||
history: () => fetchAPI<HistoryResponse>("/api/history"),
|
||||
stats: () => fetchAPI<StatsResponse>("/api/stats"),
|
||||
health: () => fetchAPI<{ status: string }>("/api/health"),
|
||||
signalsHistory: () => fetchAPI<SignalsHistoryResponse>("/api/signals/history"),
|
||||
// Public
|
||||
rates: () => fetchPublic<RatesResponse>("/api/rates"),
|
||||
health: () => fetchPublic<{ status: string }>("/api/health"),
|
||||
// Protected
|
||||
history: () => fetchProtected<HistoryResponse>("/api/history"),
|
||||
stats: () => fetchProtected<StatsResponse>("/api/stats"),
|
||||
signalsHistory: () => fetchProtected<SignalsHistoryResponse>("/api/signals/history"),
|
||||
snapshots: (hours = 24, limit = 5000) =>
|
||||
fetchAPI<SnapshotsResponse>(`/api/snapshots?hours=${hours}&limit=${limit}`),
|
||||
fetchProtected<SnapshotsResponse>(`/api/snapshots?hours=${hours}&limit=${limit}`),
|
||||
kline: (symbol = "BTC", interval = "1h", limit = 500) =>
|
||||
fetchAPI<KlineResponse>(`/api/kline?symbol=${symbol}&interval=${interval}&limit=${limit}`),
|
||||
statsYtd: () => fetchAPI<YtdStatsResponse>("/api/stats/ytd"),
|
||||
fetchProtected<KlineResponse>(`/api/kline?symbol=${symbol}&interval=${interval}&limit=${limit}`),
|
||||
statsYtd: () => fetchProtected<YtdStatsResponse>("/api/stats/ytd"),
|
||||
};
|
||||
|
||||
137
frontend/lib/auth.tsx
Normal file
137
frontend/lib/auth.tsx
Normal file
@ -0,0 +1,137 @@
|
||||
"use client";
|
||||
|
||||
import { createContext, useContext, useState, useEffect, ReactNode, useCallback } from "react";
|
||||
|
||||
interface User {
|
||||
id: number;
|
||||
email: string;
|
||||
role: string;
|
||||
}
|
||||
|
||||
interface AuthState {
|
||||
user: User | null;
|
||||
accessToken: string | null;
|
||||
loading: boolean;
|
||||
login: (email: string, password: string) => Promise<void>;
|
||||
register: (email: string, password: string, inviteCode: string) => Promise<void>;
|
||||
logout: () => void;
|
||||
isLoggedIn: boolean;
|
||||
isAdmin: boolean;
|
||||
}
|
||||
|
||||
const AuthContext = createContext<AuthState | undefined>(undefined);
|
||||
|
||||
const API_BASE = process.env.NEXT_PUBLIC_API_URL ?? "";
|
||||
|
||||
export function AuthProvider({ children }: { children: ReactNode }) {
|
||||
const [user, setUser] = useState<User | null>(null);
|
||||
const [accessToken, setAccessToken] = useState<string | null>(null);
|
||||
const [loading, setLoading] = useState(true);
|
||||
|
||||
// init from localStorage
|
||||
useEffect(() => {
|
||||
const token = localStorage.getItem("access_token");
|
||||
const saved = localStorage.getItem("user");
|
||||
if (token && saved) {
|
||||
try {
|
||||
setAccessToken(token);
|
||||
setUser(JSON.parse(saved));
|
||||
} catch {}
|
||||
}
|
||||
setLoading(false);
|
||||
}, []);
|
||||
|
||||
const saveAuth = (data: { access_token: string; refresh_token: string; user: User }) => {
|
||||
localStorage.setItem("access_token", data.access_token);
|
||||
localStorage.setItem("refresh_token", data.refresh_token);
|
||||
localStorage.setItem("user", JSON.stringify(data.user));
|
||||
setAccessToken(data.access_token);
|
||||
setUser(data.user);
|
||||
};
|
||||
|
||||
const login = useCallback(async (email: string, password: string) => {
|
||||
const res = await fetch(`${API_BASE}/api/auth/login`, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({ email, password }),
|
||||
});
|
||||
if (!res.ok) {
|
||||
const err = await res.json().catch(() => ({}));
|
||||
throw new Error(err.detail || "Login failed");
|
||||
}
|
||||
const data = await res.json();
|
||||
saveAuth(data);
|
||||
}, []);
|
||||
|
||||
const register = useCallback(async (email: string, password: string, inviteCode: string) => {
|
||||
const res = await fetch(`${API_BASE}/api/auth/register`, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({ email, password, invite_code: inviteCode }),
|
||||
});
|
||||
if (!res.ok) {
|
||||
const err = await res.json().catch(() => ({}));
|
||||
throw new Error(err.detail || "Registration failed");
|
||||
}
|
||||
const data = await res.json();
|
||||
saveAuth(data);
|
||||
}, []);
|
||||
|
||||
const logout = useCallback(() => {
|
||||
localStorage.removeItem("access_token");
|
||||
localStorage.removeItem("refresh_token");
|
||||
localStorage.removeItem("user");
|
||||
setAccessToken(null);
|
||||
setUser(null);
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<AuthContext.Provider value={{
|
||||
user, accessToken, loading, login, register, logout,
|
||||
isLoggedIn: !!user, isAdmin: user?.role === "admin",
|
||||
}}>
|
||||
{children}
|
||||
</AuthContext.Provider>
|
||||
);
|
||||
}
|
||||
|
||||
export function useAuth() {
|
||||
const ctx = useContext(AuthContext);
|
||||
if (!ctx) throw new Error("useAuth must be used within AuthProvider");
|
||||
return ctx;
|
||||
}
|
||||
|
||||
// Authenticated fetch helper
|
||||
export async function authFetch(path: string, options: RequestInit = {}): Promise<Response> {
|
||||
const token = localStorage.getItem("access_token");
|
||||
const headers = new Headers(options.headers);
|
||||
if (token) headers.set("Authorization", `Bearer ${token}`);
|
||||
|
||||
let res = await fetch(`${API_BASE}${path}`, { ...options, headers });
|
||||
|
||||
// try refresh on 401
|
||||
if (res.status === 401) {
|
||||
const refreshToken = localStorage.getItem("refresh_token");
|
||||
if (refreshToken) {
|
||||
const refreshRes = await fetch(`${API_BASE}/api/auth/refresh`, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({ refresh_token: refreshToken }),
|
||||
});
|
||||
if (refreshRes.ok) {
|
||||
const data = await refreshRes.json();
|
||||
localStorage.setItem("access_token", data.access_token);
|
||||
localStorage.setItem("refresh_token", data.refresh_token);
|
||||
headers.set("Authorization", `Bearer ${data.access_token}`);
|
||||
res = await fetch(`${API_BASE}${path}`, { ...options, headers });
|
||||
} else {
|
||||
// refresh failed, clear auth
|
||||
localStorage.removeItem("access_token");
|
||||
localStorage.removeItem("refresh_token");
|
||||
localStorage.removeItem("user");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user